diff --git a/auto_round/experimental/utils.py b/auto_round/experimental/utils.py
index 545ebd401..4af584549 100644
--- a/auto_round/experimental/utils.py
+++ b/auto_round/experimental/utils.py
@@ -71,7 +71,10 @@ def normalize_static_kv_dtype(static_kv_dtype: str | torch.dtype) -> torch.dtype
 def is_attention_module(module: torch.nn.Module):
     # FIXME: Handle this better.
     return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+        or hasattr(module, "kv_b_proj")  # for DeepSeek MLA attention
     )
 
 
diff --git a/auto_round_extension/vllm_ext/__init__.py b/auto_round_extension/vllm_ext/__init__.py
index 748e917e2..d05b3c731 100644
--- a/auto_round_extension/vllm_ext/__init__.py
+++ b/auto_round_extension/vllm_ext/__init__.py
@@ -20,6 +20,7 @@ def apply():
     import auto_round_extension.vllm_ext.auto_round_ext
     import auto_round_extension.vllm_ext.envs_ext
+    import auto_round_extension.vllm_ext.vllm_oot_patches
 
     print("*****************************************************************************")
     print("* !!! VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension    *")
diff --git a/auto_round_extension/vllm_ext/vllm_oot_patches.py b/auto_round_extension/vllm_ext/vllm_oot_patches.py
new file mode 100644
index 000000000..af05cfc0c
--- /dev/null
+++ b/auto_round_extension/vllm_ext/vllm_oot_patches.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2025 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import vllm
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def oot_maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
+    """Remap the name of FP8 k/v_scale parameters.
+
+    This function handles the remapping of FP8 k/v_scale parameter names.
+    It detects whether the given name ends with a scale suffix and attempts
+    to remap it to the expected name format in the model. If the remapped name
+    is not found in params_dict, a warning is logged and None is returned.
+
+    Args:
+        name (str): The original loaded checkpoint parameter name.
+        params_dict (dict): Dictionary containing the model's named parameters.
+
+    Returns:
+        str: The remapped parameter name if successful, or the original name
+            if no remapping is needed.
+        None: If the remapped name is not found in params_dict.
+    """
+
+    if name.endswith(".kv_scale"):
+        logger.warning_once(
+            "DEPRECATED. Found kv_scale in the checkpoint. "
+            "This format is deprecated in favor of separate k_scale and "
+            "v_scale tensors and will be removed in a future release. "
+            "Functionally, we will remap kv_scale to k_scale and duplicate "
+            "k_scale to v_scale"
+        )
+        # NOTE: we remap the deprecated kv_scale to k_scale
+        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
+        if remapped_name not in params_dict:
+            logger.warning_once(
+                "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  # noqa: E501
+                name,
+                remapped_name,
+            )
+            return None
+        return remapped_name
+
+    if any("mla_attn" in key for key in params_dict):
+        attn_str = "mla_attn.mla_attn"
+        logger.debug_once(
+            f"Found mla_attn with k_scale and v_scale in " f"the checkpoint, using {attn_str} as attn_str"
+        )
+    else:
+        attn_str = "attn"
+    # Define scale name mapping patterns in order of precedence
+    scale_mapping_patterns = [
+        # AR format:
+        # .self_attn.{q,k,v}_scale ->
+        # .self_attn.{attn_str}.{q,k,v}_scale
+        (
+            r"\.self_attn\.([qkv])_scale$",
+            rf".self_attn.{attn_str}.\1_scale",
+        ),
+        (
+            r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+            rf".self_attn.{attn_str}.\2_scale",
+        ),
+        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (
+            r"\.self_attn\.([kv])_proj\.([kv])_scale$",
+            rf".self_attn.{attn_str}.\2_scale",
+        ),
+        # QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
+        # .self_attn.attn.{k,v}_scale
+        (r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
+        # Default format: .{k,v}_scale -> .attn.{k,v}_scale
+        (r"\.([kv])_scale$", r".attn.\1_scale"),
+    ]
+
+    # Check if name ends with k_scale, v_scale, or q_scale
+    if name.endswith((".k_scale", ".v_scale", ".q_scale")):
+        import regex as re
+
+        for pattern, replacement in scale_mapping_patterns:
+            if re.search(pattern, name):
+                remapped_name = re.sub(pattern, replacement, name)
+                if remapped_name not in params_dict:
+                    scale_type = name.split(".")[-1]
+                    logger.warning_once(
+                        "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
+                        scale_type,
+                        name,
+                        remapped_name,
+                        scale_type,
+                    )
+                    return None
+                return remapped_name
+
+    # If there were no matches, return the untouched param name
+    return name
+
+
+import vllm.model_executor.model_loader.weight_utils as vllm_weight_utils
+
+vllm_weight_utils.maybe_remap_kv_scale_name = oot_maybe_remap_kv_scale_name
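
For reference, a minimal sketch (not part of the patch) of how the installed remapping is expected to behave in the default non-MLA case (`attn_str == "attn"`). The parameter and checkpoint names below are hypothetical examples, and the stdlib `re` module is used here for brevity even though the patch itself imports the third-party `regex` package:

```python
# Illustrative only: exercises the same ModelOpt-style pattern the patch registers.
# Parameter names below are made-up examples, not taken from a real model.
import re

# Hypothetical model-side parameters (what params_dict might contain in vLLM).
params_dict = {
    "model.layers.0.self_attn.attn.k_scale": 1.0,
    "model.layers.0.self_attn.attn.v_scale": 1.0,
}

# Same pattern/replacement as the non-MLA branch of the patch (attn_str == "attn").
pattern = r"\.self_attn\.([kv])_proj\.([kv])_scale$"
replacement = r".self_attn.attn.\2_scale"

checkpoint_name = "model.layers.0.self_attn.k_proj.k_scale"
remapped = re.sub(pattern, replacement, checkpoint_name)

print(remapped)                 # model.layers.0.self_attn.attn.k_scale
print(remapped in params_dict)  # True -> the scale would be loaded
```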