From 09b3a1ceb3942e63aba7a9947c6bef4df4bca5b3 Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Sun, 14 Dec 2025 22:30:32 -0500 Subject: [PATCH 1/6] refine update_fused_layer_global_scales to fix device mismatch for nvfp UT Signed-off-by: Zhang, Weiwei1 --- auto_round/data_type/utils.py | 114 +++++++++++++++++----------------- 1 file changed, 56 insertions(+), 58 deletions(-) diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index faea677e3..801aba828 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -235,78 +235,76 @@ def get_gaudi_fp8_ste_func(): # please refer from https://github.com/vllm-project/llm-compressor/blob/ # 29f4d5644b48e9c8ebb7e36d5be9f7c92747ceb7/src/llmcompressor/modifiers/utils/helpers.py#L11 -def update_fused_layer_global_scales(submodule: torch.nn.Module, base_name="weight"): +def update_fused_layer_global_scales( + submodule: Module, + base_name: str = "weight", +): """ - When running NVFP4 quantization, update the global scale - such that q,k,v layers are treated as one tensor with the same - global_scale and gate_proj/up_proj layers are treated as one tensor - with the same global scale. This is requirement currently being set - by vLLM and may be removed in the future OR potentially make it - an optional step. - - :param model: model to quantize - base_name: op name for fuse usage, option: weight, input + Update global scales for fused layers under NVFP4 quantization. + + For attention layers: + - q/k/v projections share a single global scale. + + For MLP layers: + - gate_proj and up_proj share a single global scale. + + This behavior is currently required by vLLM and may become optional + in the future. """ global_scale_name = f"{base_name}_global_scale" - max_value_tensor = torch.ones(1, device="cpu", dtype=torch.float32) * torch.finfo(torch.float32).max - def _is_attention_module(module: Module): - return "attention" in module.__class__.__name__.lower() and ( - hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj") + def _collect_scales(mods: List[Module]) -> List[torch.Tensor]: + """Collect valid global_scale tensors from modules.""" + scales = [] + for m in mods: + if hasattr(m, global_scale_name): + scale = getattr(m, global_scale_name) + if isinstance(scale, torch.Tensor): + scales.append(scale) + return scales + + def _is_attention_module(m: Module) -> bool: + name = m.__class__.__name__.lower() + return "attention" in name and ( + hasattr(m, "q_proj") and hasattr(m, "k_proj") and hasattr(m, "v_proj") or hasattr(m, "qkv_proj") ) - def _is_mlp_module(module: Module): - return "mlp" in module.__class__.__name__.lower() and ( - hasattr(module, "gate_proj") or hasattr(module, "up_proj") - ) + def _is_mlp_module(m: Module) -> bool: + name = m.__class__.__name__.lower() + return "mlp" in name and hasattr(m, "gate_proj") and hasattr(m, "up_proj") + # ---------------- Attention ---------------- if _is_attention_module(submodule): - # already fused/treated as one layer + # Already fused if hasattr(submodule, "qkv_proj"): return - q_global_scale = getattr(submodule.q_proj, global_scale_name, max_value_tensor) - q_global_scale = max_value_tensor if q_global_scale is None else q_global_scale - k_global_scale = getattr(submodule.k_proj, global_scale_name, max_value_tensor) - k_global_scale = max_value_tensor if k_global_scale is None else k_global_scale - v_global_scale = getattr(submodule.v_proj, global_scale_name, max_value_tensor) - v_global_scale = max_value_tensor if 
v_global_scale is None else v_global_scale - - global_scale = torch.min( - torch.cat( - ( - q_global_scale.reshape(1), - k_global_scale.reshape(1), - v_global_scale.reshape(1), - ) - ) - ).reshape([1]) - - if math.isclose(global_scale.data, max_value_tensor.data, rel_tol=1e-9): + scales = _collect_scales([submodule.q_proj, submodule.k_proj, submodule.v_proj]) + if not scales: return - if hasattr(submodule.q_proj, global_scale_name): - setattr(submodule.q_proj, global_scale_name, global_scale.clone()) - if hasattr(submodule.k_proj, global_scale_name): - setattr(submodule.k_proj, global_scale_name, global_scale.clone()) - if hasattr(submodule.v_proj, global_scale_name): - setattr(submodule.v_proj, global_scale_name, global_scale.clone()) - del global_scale + device = scales[0].device + dtype = scales[0].dtype + + global_scale = torch.stack([s.to(device=device, dtype=dtype).reshape(1) for s in scales]).min(dim=0).values + for proj in (submodule.q_proj, submodule.k_proj, submodule.v_proj): + if hasattr(proj, global_scale_name): + setattr(proj, global_scale_name, global_scale.clone()) + + return + + # ---------------- MLP ---------------- if _is_mlp_module(submodule): - global_scale = torch.min( - torch.cat( - ( - getattr(submodule.gate_proj, global_scale_name, max_value_tensor).reshape(1), - getattr(submodule.up_proj, global_scale_name, max_value_tensor).reshape(1), - ) - ) - ).reshape([1]) - if math.isclose(global_scale.data, max_value_tensor.data, rel_tol=1e-9): + scales = _collect_scales([submodule.gate_proj, submodule.up_proj]) + if not scales: return - if hasattr(submodule.gate_proj, global_scale_name): - setattr(submodule.gate_proj, global_scale_name, global_scale.clone()) - if hasattr(submodule.up_proj, global_scale_name): - setattr(submodule.up_proj, global_scale_name, global_scale.clone()) - del global_scale + device = scales[0].device + dtype = scales[0].dtype + + global_scale = torch.stack([s.to(device=device, dtype=dtype).reshape(1) for s in scales]).min(dim=0).values + + for proj in (submodule.gate_proj, submodule.up_proj): + if hasattr(proj, global_scale_name): + setattr(proj, global_scale_name, global_scale.clone()) From 5d816459c8be7624eb98f49d6b42d57c797bec5a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Dec 2025 06:55:12 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/data_type/utils.py | 1 - auto_round/modelling/qwen3_vl_moe.py | 48 ++++++++++++---------------- 2 files changed, 20 insertions(+), 29 deletions(-) diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index b00aa873e..890329f5c 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -319,4 +319,3 @@ def _is_mlp_module(module: Module): for proj in (submodule.gate_proj, submodule.up_proj): if hasattr(proj, global_scale_name): setattr(proj, global_scale_name, global_scale.clone()) - diff --git a/auto_round/modelling/qwen3_vl_moe.py b/auto_round/modelling/qwen3_vl_moe.py index 03ec00d10..61c0db56e 100644 --- a/auto_round/modelling/qwen3_vl_moe.py +++ b/auto_round/modelling/qwen3_vl_moe.py @@ -13,14 +13,16 @@ # limitations under the License. 
import torch -from torch import nn -from transformers.activations import ACT2FN import transformers from packaging import version -from auto_round.utils import logger -from auto_round.utils import unsupported_meta_device +from torch import nn +from transformers.activations import ACT2FN + +from auto_round.utils import logger, unsupported_meta_device + transformers_version = version.parse(transformers.__version__) + def _update_parameter( module: torch.nn.Module, name: str, @@ -43,7 +45,7 @@ def __init__( self, original: "Qwen3VLMoeTextSparseMoeBlock", config: "Qwen3VLMoeConfig", - calibrate_all_experts: bool=True, + calibrate_all_experts: bool = True, ): super().__init__() text_config: "Qwen3VLMoeTextConfig" = config.get_text_config() @@ -56,10 +58,12 @@ def __init__( self.gate = original.gate self.calibrate_all_experts = calibrate_all_experts self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) - if not transformers_version <= version.parse("4.57.3"): # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0 + if not transformers_version <= version.parse( + "4.57.3" + ): # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0 from transformers.conversion_mapping import register_checkpoint_conversion_mapping - register_checkpoint_conversion_mapping(config.model_type, [], overwrite=True) + register_checkpoint_conversion_mapping(config.model_type, [], overwrite=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -67,15 +71,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - routing_weights = torch.nn.functional.softmax( - router_logits, dim=1, dtype=torch.float - ) + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) # get topk experts per token # routing_weight: (num_tokens, top_k) # routing_indices: (num_tokens, top_k) - routing_weights, router_indices = torch.topk( - routing_weights, self.top_k, dim=-1 - ) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) @@ -87,9 +87,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # convert router indices into OHE list # reshape to be (num_experts, top_k, batch_size * sequence_length) - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) for expert_idx, expert_layer in enumerate(self.experts): idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) @@ -103,9 +101,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # if there are tokens meant for this expert, further scale the expert # output by the score weighted_output = expert_out * routing_weights[token_idx, idx, None] - next_states.index_add_( - 0, token_idx, weighted_output.to(hidden_states.dtype) - ) + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.reshape(batch_size, sequence_length, hidden_dim) if transformers_version <= version.parse("4.57.3"): @@ -125,10 +121,9 @@ def __init__(self, config, original): from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( Qwen3VLMoeTextMLP, ) - 
super().__init__( - [Qwen3VLMoeTextMLP(config, intermediate_size) for _ in range(self.num_experts)] - ) - + + super().__init__([Qwen3VLMoeTextMLP(config, intermediate_size) for _ in range(self.num_experts)]) + if not unsupported_meta_device(original): for i in range(self.num_experts): gate_up = original.gate_up_proj[i] @@ -141,9 +136,6 @@ def __init__(self, config, original): _update_parameter(self[i].up_proj, "weight", up_proj.t().contiguous()) _update_parameter(self[i].down_proj, "weight", down.t().contiguous()) + def get_replacement_info(config): - return ( - LinearQwen3VLMoeTextSparseMoeBlock, - config, - "Qwen3VLMoeTextSparseMoeBlock" - ) + return (LinearQwen3VLMoeTextSparseMoeBlock, config, "Qwen3VLMoeTextSparseMoeBlock") From 9f61c9e4d53812036ab76bf5b7a2e0c20b64fc18 Mon Sep 17 00:00:00 2001 From: "Zhang, Weiwei1" Date: Tue, 23 Dec 2025 02:02:35 -0500 Subject: [PATCH 3/6] fixtypo Signed-off-by: Zhang, Weiwei1 --- auto_round/data_type/utils.py | 1 - auto_round/modelling/qwen3_vl_moe.py | 56 ++++++++++++++-------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py index b00aa873e..890329f5c 100644 --- a/auto_round/data_type/utils.py +++ b/auto_round/data_type/utils.py @@ -319,4 +319,3 @@ def _is_mlp_module(module: Module): for proj in (submodule.gate_proj, submodule.up_proj): if hasattr(proj, global_scale_name): setattr(proj, global_scale_name, global_scale.clone()) - diff --git a/auto_round/modelling/qwen3_vl_moe.py b/auto_round/modelling/qwen3_vl_moe.py index 03ec00d10..b292934d3 100644 --- a/auto_round/modelling/qwen3_vl_moe.py +++ b/auto_round/modelling/qwen3_vl_moe.py @@ -13,14 +13,24 @@ # limitations under the License. import torch -from torch import nn -from transformers.activations import ACT2FN import transformers from packaging import version -from auto_round.utils import logger -from auto_round.utils import unsupported_meta_device +from torch import nn +from transformers.activations import ACT2FN + +from auto_round.utils import logger, unsupported_meta_device + transformers_version = version.parse(transformers.__version__) +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeTextSparseMoeBlock, + ) + + def _update_parameter( module: torch.nn.Module, name: str, @@ -43,7 +53,7 @@ def __init__( self, original: "Qwen3VLMoeTextSparseMoeBlock", config: "Qwen3VLMoeConfig", - calibrate_all_experts: bool=True, + calibrate_all_experts: bool = True, ): super().__init__() text_config: "Qwen3VLMoeTextConfig" = config.get_text_config() @@ -56,10 +66,12 @@ def __init__( self.gate = original.gate self.calibrate_all_experts = calibrate_all_experts self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) - if not transformers_version <= version.parse("4.57.3"): # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0 + if not transformers_version <= version.parse( + "4.57.3" + ): # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0 from transformers.conversion_mapping import register_checkpoint_conversion_mapping - register_checkpoint_conversion_mapping(config.model_type, [], overwrite=True) + register_checkpoint_conversion_mapping(config.model_type, [], overwrite=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -67,15 +79,11 @@ def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - routing_weights = torch.nn.functional.softmax( - router_logits, dim=1, dtype=torch.float - ) + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) # get topk experts per token # routing_weight: (num_tokens, top_k) # routing_indices: (num_tokens, top_k) - routing_weights, router_indices = torch.topk( - routing_weights, self.top_k, dim=-1 - ) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) routing_weights = routing_weights.to(hidden_states.dtype) @@ -87,9 +95,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # convert router indices into OHE list # reshape to be (num_experts, top_k, batch_size * sequence_length) - expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=self.num_experts - ).permute(2, 1, 0) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) for expert_idx, expert_layer in enumerate(self.experts): idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) @@ -103,9 +109,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # if there are tokens meant for this expert, further scale the expert # output by the score weighted_output = expert_out * routing_weights[token_idx, idx, None] - next_states.index_add_( - 0, token_idx, weighted_output.to(hidden_states.dtype) - ) + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.reshape(batch_size, sequence_length, hidden_dim) if transformers_version <= version.parse("4.57.3"): @@ -125,10 +129,9 @@ def __init__(self, config, original): from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( Qwen3VLMoeTextMLP, ) - super().__init__( - [Qwen3VLMoeTextMLP(config, intermediate_size) for _ in range(self.num_experts)] - ) - + + super().__init__([Qwen3VLMoeTextMLP(config, intermediate_size) for _ in range(self.num_experts)]) + if not unsupported_meta_device(original): for i in range(self.num_experts): gate_up = original.gate_up_proj[i] @@ -141,9 +144,6 @@ def __init__(self, config, original): _update_parameter(self[i].up_proj, "weight", up_proj.t().contiguous()) _update_parameter(self[i].down_proj, "weight", down.t().contiguous()) + def get_replacement_info(config): - return ( - LinearQwen3VLMoeTextSparseMoeBlock, - config, - "Qwen3VLMoeTextSparseMoeBlock" - ) + return (LinearQwen3VLMoeTextSparseMoeBlock, config, "Qwen3VLMoeTextSparseMoeBlock") From 2afa2698c4ecf94291e59cc82df37c1028b9d1c8 Mon Sep 17 00:00:00 2001 From: Weiwei Date: Wed, 24 Dec 2025 10:34:14 +0800 Subject: [PATCH 4/6] Update auto_round/modelling/qwen3_vl_moe.py Co-authored-by: Yi Liu --- auto_round/modelling/qwen3_vl_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/auto_round/modelling/qwen3_vl_moe.py b/auto_round/modelling/qwen3_vl_moe.py index b292934d3..b024498eb 100644 --- a/auto_round/modelling/qwen3_vl_moe.py +++ b/auto_round/modelling/qwen3_vl_moe.py @@ -117,7 +117,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: else: return next_states - def restore(self, original: torch.nn.Module) -> torch.nn.Module: return original From c268e3dc44fe7327e88df06f532b7e1c62ab5f19 Mon Sep 17 00:00:00 2001 From: Weiwei Date: Wed, 24 Dec 2025 10:35:21 +0800 Subject: [PATCH 5/6] set calib_all_experts to 
false Co-authored-by: Yi Liu --- auto_round/modelling/qwen3_vl_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/modelling/qwen3_vl_moe.py b/auto_round/modelling/qwen3_vl_moe.py index b024498eb..206ec3d24 100644 --- a/auto_round/modelling/qwen3_vl_moe.py +++ b/auto_round/modelling/qwen3_vl_moe.py @@ -53,7 +53,7 @@ def __init__( self, original: "Qwen3VLMoeTextSparseMoeBlock", config: "Qwen3VLMoeConfig", - calibrate_all_experts: bool = True, + calibrate_all_experts: bool = False, ): super().__init__() text_config: "Qwen3VLMoeTextConfig" = config.get_text_config() From b9b8914a6f8425f54b3b08866061a1d8275b7516 Mon Sep 17 00:00:00 2001 From: Weiwei Date: Wed, 24 Dec 2025 10:37:55 +0800 Subject: [PATCH 6/6] fix typo --- auto_round/modelling/qwen3_vl_moe.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/auto_round/modelling/qwen3_vl_moe.py b/auto_round/modelling/qwen3_vl_moe.py index 206ec3d24..9a47871e5 100644 --- a/auto_round/modelling/qwen3_vl_moe.py +++ b/auto_round/modelling/qwen3_vl_moe.py @@ -117,8 +117,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: else: return next_states - return original - class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): def __init__(self, config, original):
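
Note on PATCH 1/6: the refactored update_fused_layer_global_scales collects whatever
weight_global_scale tensors exist on the q/k/v (or gate_proj/up_proj) submodules, casts
them to the device and dtype of the first scale, and writes the minimum back to every
projection. The snippet below is a minimal usage sketch, not part of the patch: ToyAttention
is a made-up stand-in (its class name contains "attention" and it exposes q_proj/k_proj/v_proj,
which is what the helper checks for), and it assumes auto-round with this patch applied is
installed so the helper is importable from auto_round.data_type.utils as in the diff.

    import torch
    from torch import nn

    from auto_round.data_type.utils import update_fused_layer_global_scales


    class ToyAttention(nn.Module):
        """Illustrative stand-in for an attention block with separate q/k/v projections."""

        def __init__(self, hidden: int = 8):
            super().__init__()
            self.q_proj = nn.Linear(hidden, hidden, bias=False)
            self.k_proj = nn.Linear(hidden, hidden, bias=False)
            self.v_proj = nn.Linear(hidden, hidden, bias=False)


    attn = ToyAttention()
    # Attach per-projection global scales; in a real NVFP4 flow these tensors can
    # end up on different devices, which is the mismatch this patch addresses.
    attn.q_proj.weight_global_scale = torch.tensor([3.0])
    attn.k_proj.weight_global_scale = torch.tensor([2.0])
    attn.v_proj.weight_global_scale = torch.tensor([5.0])

    update_fused_layer_global_scales(attn, base_name="weight")

    # All three projections now share the fused (minimum) scale.
    print(attn.q_proj.weight_global_scale)  # tensor([2.])
    print(attn.k_proj.weight_global_scale)  # tensor([2.])
    print(attn.v_proj.weight_global_scale)  # tensor([2.])

Because every collected scale is moved to a common device and dtype before the reduction,
mixing CPU- and accelerator-resident scales no longer triggers the device-mismatch error seen
in the NVFP unit test, and layers without a global scale are simply skipped instead of being
padded with a float32 max sentinel.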
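
Note on PATCHes 2/6-6/6: SequentialQwen3VLMoeTextExperts unpacks the fused 3-D expert
weights of the original Qwen3-VL-MoE experts module (gate_up_proj of roughly shape
(num_experts, hidden, 2 * intermediate) and down_proj of roughly shape
(num_experts, intermediate, hidden)) into per-expert Linear layers so each expert can be
calibrated and quantized independently. The sketch below mirrors that slicing with plain
torch under those shape assumptions; the sizes and the nn.ModuleDict container are
illustrative only and are not the Qwen3VLMoeTextMLP classes the patch actually instantiates.

    import torch
    from torch import nn

    # Illustrative sizes only; real values come from the model config.
    num_experts, hidden, intermediate = 4, 16, 32

    # Fused expert weights in the assumed checkpoint layout.
    gate_up_proj = torch.randn(num_experts, hidden, 2 * intermediate)
    down_proj = torch.randn(num_experts, intermediate, hidden)

    experts = nn.ModuleList()
    for i in range(num_experts):
        mlp = nn.ModuleDict(
            {
                "gate_proj": nn.Linear(hidden, intermediate, bias=False),
                "up_proj": nn.Linear(hidden, intermediate, bias=False),
                "down_proj": nn.Linear(intermediate, hidden, bias=False),
            }
        )
        # Split the fused (hidden, 2*intermediate) matrix into gate and up halves,
        # then transpose into the (out_features, in_features) layout nn.Linear expects.
        gate, up = gate_up_proj[i].chunk(2, dim=-1)
        mlp["gate_proj"].weight.data.copy_(gate.t().contiguous())
        mlp["up_proj"].weight.data.copy_(up.t().contiguous())
        mlp["down_proj"].weight.data.copy_(down_proj[i].t().contiguous())
        experts.append(mlp)

The transpose is the key step: the fused tensors store each expert's weights as
(in_features, out_features), whereas nn.Linear.weight is (out_features, in_features). Once
unfused, the replacement block's forward can route tokens through these per-expert Linear
layers (optionally through all experts when calibrate_all_experts is enabled), which is what
lets the quantizer observe and wrap each expert as an ordinary Linear module.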