5 changes: 4 additions & 1 deletion auto_round/experimental/utils.py
@@ -71,7 +71,10 @@ def normalize_static_kv_dtype(static_kv_dtype: str | torch.dtype) -> torch.dtype
def is_attention_module(module: torch.nn.Module):
# FIXME: Handle this better.
return "attention" in module.__class__.__name__.lower() and (
hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
hasattr(module, "k_proj")
or hasattr(module, "v_proj")
or hasattr(module, "qkv_proj")
        or hasattr(module, "kv_b_proj")  # for DeepSeek-style MLA attention
)
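
As a quick illustration (a hypothetical MLA-style module, not taken from the repo), the added kv_b_proj check lets the predicate catch attention blocks that fuse K/V into a single low-rank projection instead of exposing separate k_proj/v_proj:

import torch
from auto_round.experimental.utils import is_attention_module  # the predicate above

class MyMLAAttention(torch.nn.Module):  # hypothetical DeepSeek-style MLA block
    def __init__(self):
        super().__init__()
        # MLA fuses K/V into one low-rank up-projection; no separate k_proj/v_proj.
        self.kv_b_proj = torch.nn.Linear(512, 256, bias=False)

# "attention" appears in the class name and kv_b_proj exists -> detected.
assert is_attention_module(MyMLAAttention())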


1 change: 1 addition & 0 deletions auto_round_extension/vllm_ext/__init__.py
@@ -20,6 +20,7 @@
def apply():
import auto_round_extension.vllm_ext.auto_round_ext
import auto_round_extension.vllm_ext.envs_ext
import auto_round_extension.vllm_ext.vllm_oot_patches

print("*****************************************************************************")
print("* !!! VLLM_ENABLE_AR_EXT is set to 1, applying auto_round_vllm_extension *")
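
For context, a hedged sketch of how this patch module gets pulled in (the env-var gate is inferred from the banner text above; the exact activation wiring is an assumption):

import os

# Assumption: VLLM_ENABLE_AR_EXT=1 is what triggers apply(), per the banner above.
os.environ["VLLM_ENABLE_AR_EXT"] = "1"

from auto_round_extension.vllm_ext import apply

apply()  # importing vllm_oot_patches installs the out-of-tree monkey patches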
119 changes: 119 additions & 0 deletions auto_round_extension/vllm_ext/vllm_oot_patches.py
@@ -0,0 +1,119 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from vllm.logger import init_logger

logger = init_logger(__name__)


def oot_maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
"""Remap the name of FP8 k/v_scale parameters.

    It detects whether the given name ends with a known scale suffix
    (".kv_scale", ".q_scale", ".k_scale", or ".v_scale") and remaps it to
    the name format the model expects. If the remapped name is not found in
    params_dict, a warning is logged and None is returned.

Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.

Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""

if name.endswith(".kv_scale"):
logger.warning_once(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale"
)
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
if remapped_name not in params_dict:
            logger.warning_once(
                "Found kv_scale in the checkpoint (e.g. %s), but the expected name (e.g. %s) was not found in the model; kv_scale is not loaded.",  # noqa: E501
                name,
                remapped_name,
            )
return None
return remapped_name

if any("mla_attn" in key for key in params_dict):
attn_str = "mla_attn.mla_attn"
        logger.debug_once(f"Found mla_attn in the model parameters; using {attn_str} as the attention prefix")
else:
attn_str = "attn"
# Define scale name mapping patterns in order of precedence
scale_mapping_patterns = [
        # AR format: .self_attn.{q,k,v}_scale ->
        # .self_attn.{attn_str}.{q,k,v}_scale
        (
            r"\.self_attn\.([qkv])_scale$",
            rf".self_attn.{attn_str}.\1_scale",
        ),
        # ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
        # .self_attn.{attn_str}.{k,v}_scale
        (
            r"\.self_attn\.([kv])_proj\.([kv])_scale$",
            rf".self_attn.{attn_str}.\2_scale",
        ),
# QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
# .self_attn.attn.{k,v}_scale
(r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
# Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
# .self_attn.attn.{k,v}_scale
(r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
# Default format: .{k,v}_scale -> .attn.{k,v}_scale
(r"\.([kv])_scale$", r".attn.\1_scale"),
]

    # Check whether the name ends with q_scale, k_scale, or v_scale
if name.endswith((".k_scale", ".v_scale", ".q_scale")):
        import regex as re  # third-party "regex" package; drop-in for re here

for pattern, replacement in scale_mapping_patterns:
if re.search(pattern, name):
remapped_name = re.sub(pattern, replacement, name)
if remapped_name not in params_dict:
scale_type = name.split(".")[-1]
                    logger.warning_once(
                        "Found %s in the checkpoint (e.g. %s), but the expected name (e.g. %s) was not found in the model; %s is not loaded.",  # noqa: E501
                        scale_type,
                        name,
                        remapped_name,
                        scale_type,
                    )
return None
return remapped_name

# If there were no matches, return the untouched param name
return name


import vllm.model_executor.model_loader.weight_utils as vllm_weight_utils

# Out-of-tree patch: swap in the remapper above so AR-format q/k/v scale names
# are resolved against the model's fused attention parameters at load time.
vllm_weight_utils.maybe_remap_kv_scale_name = oot_maybe_remap_kv_scale_name
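
To make the remapping concrete, here is a minimal sketch (hypothetical layer names; params stands in for the model's named parameters):

params = {
    "model.layers.0.self_attn.attn.k_scale": None,
    "model.layers.0.self_attn.attn.v_scale": None,
}

# ModelOpt-style checkpoint name -> fused attention name.
assert (oot_maybe_remap_kv_scale_name("model.layers.0.self_attn.v_proj.v_scale", params)
        == "model.layers.0.self_attn.attn.v_scale")

# Deprecated fused kv_scale is remapped to the k_scale entry.
assert (oot_maybe_remap_kv_scale_name("model.layers.0.self_attn.kv_scale", params)
        == "model.layers.0.self_attn.attn.k_scale")

# Names without a scale suffix pass through unchanged.
assert (oot_maybe_remap_kv_scale_name("model.layers.0.mlp.down_proj.weight", params)
        == "model.layers.0.mlp.down_proj.weight")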