From f32b2bc4e30e7dc1b195a48c1517de89f7e44f24 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Mon, 22 Dec 2025 17:40:52 -0800
Subject: [PATCH 1/3] detect MLA attention modules

Signed-off-by: yiliu30
---
 auto_round/experimental/utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py
index 545ebd401..4af584549 100644
--- a/auto_round/experimental/utils.py
+++ b/auto_round/experimental/utils.py
@@ -71,7 +71,10 @@ def normalize_static_kv_dtype(static_kv_dtype: str | torch.dtype) -> torch.dtype
 def is_attention_module(module: torch.nn.Module):
     # FIXME: Handle this better.
     return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+        or hasattr(module, "kv_b_proj")  # for DeepSeek-style MLA
     )
 
 
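For context: DeepSeek-style MLA (multi-head latent attention) blocks expose a joint `kv_b_proj` up-projection instead of the separate `k_proj`/`v_proj` layers the old check looked for, so MLA attention modules were never detected. A minimal sketch of the behavior this patch enables (the two module classes below are hypothetical stand-ins, not real model code):

```python
# Sketch only: both attention classes are illustrative stand-ins; the real
# function under test is auto_round.experimental.utils.is_attention_module.
import torch

from auto_round.experimental.utils import is_attention_module


class DeepseekMLAAttention(torch.nn.Module):  # MLA-style layout (hypothetical)
    def __init__(self):
        super().__init__()
        self.kv_b_proj = torch.nn.Linear(16, 16)  # joint K/V up-projection


class PlainAttention(torch.nn.Module):  # classic MHA/GQA layout (hypothetical)
    def __init__(self):
        super().__init__()
        self.k_proj = torch.nn.Linear(16, 16)
        self.v_proj = torch.nn.Linear(16, 16)


print(is_attention_module(PlainAttention()))        # True before and after the patch
print(is_attention_module(DeepseekMLAAttention()))  # True only with the kv_b_proj check
```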
" + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale" + ) + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + logger.warning_once( + "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.", # noqa: E501 + name, + remapped_name, + ) + return None + return remapped_name + + if any("mla_attn" in key for key in params_dict): + attn_str = "mla_attn.mla_attn" + logger.debug_once( + f"Found mla_attn with k_scale and v_scale in " + f"the checkpoint, using {attn_str} as attn_str" + ) + else: + attn_str = "attn" + # Define scale name mapping patterns in order of precedence + scale_mapping_patterns = [ + # AR format: + # .self_attn.{q,k,v}_scale -> + # .attn.{attn_str}.{q,k,v}_scale + ( + r"\.self_attn\.([qkv])_scale$", + rf".self_attn.{attn_str}.\1_scale", + ), + ( + r"\.self_attn\.([kv])_proj\.([kv])_scale$", + rf".self_attn.{attn_str}.\2_scale", + ), + # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale -> + # .self_attn.attn.{k,v}_scale + ( + r"\.self_attn\.([kv])_proj\.([kv])_scale$", + rf".self_attn.{attn_str}.\2_scale", + ), + # QKV proj format: .self_attn.qkv_proj.{k,v}_scale -> + # .self_attn.attn.{k,v}_scale + (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), + # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale -> + # .self_attn.attn.{k,v}_scale + (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"), + # Default format: .{k,v}_scale -> .attn.{k,v}_scale + (r"\.([kv])_scale$", r".attn.\1_scale"), + ] + + # Check if name ends with k_scale or v_scale + if name.endswith((".k_scale", ".v_scale", ".q_scale")): + import regex as re + + for pattern, replacement in scale_mapping_patterns: + if re.search(pattern, name): + remapped_name = re.sub(pattern, replacement, name) + if remapped_name not in params_dict: + scale_type = name.split(".")[-1] + logger.warning_once( + "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501 + scale_type, + name, + remapped_name, + scale_type, + ) + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name + + +import vllm.model_executor.model_loader.weight_utils as vllm_weight_utils +vllm_weight_utils.maybe_remap_kv_scale_name = oot_maybe_remap_kv_scale_name \ No newline at end of file From af3fe70fa3fc383bb889ebc1245472ec2da3ba87 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 01:52:15 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round_extension/vllm_ext/__init__.py | 2 +- .../vllm_ext/vllm_oot_patches.py | 22 ++++++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py index 7dbb55c6c..d05b3c731 100644 --- a/auto_round_extension/vllm_ext/__init__.py +++ b/auto_round_extension/vllm_ext/__init__.py @@ -21,7 +21,7 @@ def apply(): import auto_round_extension.vllm_ext.auto_round_ext import auto_round_extension.vllm_ext.envs_ext import auto_round_extension.vllm_ext.vllm_oot_patches + print("*****************************************************************************") print("* !!! 
From af3fe70fa3fc383bb889ebc1245472ec2da3ba87 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 23 Dec 2025 01:52:15 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 auto_round_extension/vllm_ext/__init__.py |  2 +-
 .../vllm_ext/vllm_oot_patches.py           | 22 ++++++++++++++++++---
 2 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py
index 7dbb55c6c..d05b3c731 100644
--- a/auto_round_extension/vllm_ext/__init__.py
+++ b/auto_round_extension/vllm_ext/__init__.py
@@ -21,7 +21,7 @@ def apply():
     import auto_round_extension.vllm_ext.auto_round_ext
     import auto_round_extension.vllm_ext.envs_ext
     import auto_round_extension.vllm_ext.vllm_oot_patches
+
     print("*****************************************************************************")
     print("* !!! VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension *")
     print("*****************************************************************************")
-

diff --git a/auto_round_extension/vllm_ext/vllm_oot_patches.py b/auto_round_extension/vllm_ext/vllm_oot_patches.py
index 8926424c6..af05cfc0c 100644
--- a/auto_round_extension/vllm_ext/vllm_oot_patches.py
+++ b/auto_round_extension/vllm_ext/vllm_oot_patches.py
@@ -1,7 +1,23 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import vllm
 from vllm.logger import init_logger
+
 logger = init_logger(__name__)
+
 
 def oot_maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
     """Remap the name of FP8 k/v_scale parameters.
 
@@ -42,8 +58,7 @@ def oot_maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
     if any("mla_attn" in key for key in params_dict):
         attn_str = "mla_attn.mla_attn"
         logger.debug_once(
-            f"Found mla_attn with k_scale and v_scale in "
-            f"the checkpoint, using {attn_str} as attn_str"
+            f"Found mla_attn with k_scale and v_scale in " f"the checkpoint, using {attn_str} as attn_str"
         )
     else:
         attn_str = "attn"
@@ -96,4 +111,5 @@ def oot_maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
 
 
 import vllm.model_executor.model_loader.weight_utils as vllm_weight_utils
-vllm_weight_utils.maybe_remap_kv_scale_name = oot_maybe_remap_kv_scale_name
\ No newline at end of file
+
+vllm_weight_utils.maybe_remap_kv_scale_name = oot_maybe_remap_kv_scale_name
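One note on the patching mechanism used here: the last two lines of `vllm_oot_patches.py` rebind `maybe_remap_kv_scale_name` as an attribute of `vllm.model_executor.model_loader.weight_utils`, so the override is only visible to call sites that resolve the name through the module at call time. Code that grabbed the function earlier via `from ... import maybe_remap_kv_scale_name` keeps the original binding, which is presumably why `apply()` imports `vllm_oot_patches` up front, before any weights are loaded. A toy illustration (the module and function names below are made up, not vLLM's):

```python
# Toy demonstration of module-attribute rebinding vs. an earlier from-import.
import types

mod = types.ModuleType("toy_weight_utils")
mod.remap = lambda name: name                    # "original" implementation

held = mod.remap                                 # like `from mod import remap` done early
mod.remap = lambda name: name.replace(".kv_scale", ".attn.k_scale")  # the monkey-patch

print(mod.remap("layer.kv_scale"))  # layer.attn.k_scale  -> sees the override
print(held("layer.kv_scale"))       # layer.kv_scale      -> kept the old binding
```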