diff --git a/auto_round/experimental/qmodules/fp4_utils.py b/auto_round/experimental/qmodules/fp4_utils.py
index 24ab8b534..e755a2b38 100644
--- a/auto_round/experimental/qmodules/fp4_utils.py
+++ b/auto_round/experimental/qmodules/fp4_utils.py
@@ -31,7 +31,18 @@
 
 import torch
 
-kE2M1ToFloat = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)
+_DEVICE_E2M1_TENSORS = {}
+
+# Constants for FP4 values (E2M1 format)
+_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
+
+
+def get_e2m1_tensor(device):
+    """Get the device-specific E2M1 lookup tensor, creating it if needed."""
+    device_str = str(device)
+    if device_str not in _DEVICE_E2M1_TENSORS:
+        _DEVICE_E2M1_TENSORS[device_str] = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=device)
+    return _DEVICE_E2M1_TENSORS[device_str]
 
 
 def unpack_fp4_from_uint8(
@@ -91,7 +102,7 @@ def _unpack_fp4_from_uint8(
     abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices
 
     # Device-aware lookup and sign application
-    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    kE2M1 = get_e2m1_tensor(device=a.device)
     values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
 
     # Reshape to final form
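For readers unfamiliar with the encoding: each FP4 value is a 4-bit E2M1 code whose low three bits index the magnitude table above and whose remaining bit carries the sign. Below is a standalone decoding sketch; the nibble layout (sign assumed in bit 3) is inferred from the "combined & 0x07" masking and sign handling in _unpack_fp4_from_uint8, and decode_e2m1_nibbles is a hypothetical helper used only for illustration, not repository code.

import torch

_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1_nibbles(nibbles: torch.Tensor) -> torch.Tensor:
    """Map 4-bit E2M1 codes stored in a uint8 tensor to float32 values (illustrative)."""
    lut = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=nibbles.device)
    magnitudes = lut[(nibbles & 0x07).long()]  # low 3 bits: magnitude index
    signs = torch.where((nibbles & 0x08) != 0, -1.0, 1.0)  # bit 3 assumed to carry the sign
    return magnitudes * signs


# 0b0001 -> 0.5, 0b1001 -> -0.5, 0b0111 -> 6.0
print(decode_e2m1_nibbles(torch.tensor([0b0001, 0b1001, 0b0111], dtype=torch.uint8)))

Caching the lookup tensor per device, as get_e2m1_tensor now does, avoids re-uploading this small table to the accelerator on every unpack call.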
+ """ + + is_permanent = True + + def __init__( + self, + original: "Qwen3VLMoeTextSparseMoeBlock", + config: "Qwen3VLMoeConfig", + calibrate_all_experts: bool = False, + ): + super().__init__() + text_config: "Qwen3VLMoeTextConfig" = config.get_text_config() + + self.hidden_size = text_config.hidden_size + self.num_experts = text_config.num_experts + self.top_k = original.top_k + # Note: gate was changed to be a Linear layer in transformers==4.57.0 + # https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee + self.gate = original.gate + self.calibrate_all_experts = calibrate_all_experts + self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts) + if not transformers_version <= version.parse( + "4.57.3" + ): # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0 + from transformers.conversion_mapping import register_checkpoint_conversion_mapping + + register_checkpoint_conversion_mapping(config.model_type, [], overwrite=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.reshape(-1, hidden_dim) + + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) + # get topk experts per token + # routing_weight: (num_tokens, top_k) + # routing_indices: (num_tokens, top_k) + routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + routing_weights = routing_weights.to(hidden_states.dtype) + + next_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # convert router indices into OHE list + # reshape to be (num_experts, top_k, batch_size * sequence_length) + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx, expert_layer in enumerate(self.experts): + idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0)) + + if self.calibrate_all_experts: + expert_out = expert_layer(hidden_states)[token_idx] + else: + expert_out = expert_layer(hidden_states[token_idx]) + + if len(token_idx) > 0: + # if there are tokens meant for this expert, further scale the expert + # output by the score + weighted_output = expert_out * routing_weights[token_idx, idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.reshape(batch_size, sequence_length, hidden_dim) + + if transformers_version <= version.parse("4.57.3"): + return next_states, router_logits + else: + return next_states + + +class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList): + def __init__(self, config, original): + super().__init__() + self.num_experts = original.gate_up_proj.shape[0] + intermediate_size = config.moe_intermediate_size + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( + Qwen3VLMoeTextMLP, + ) + + super().__init__([Qwen3VLMoeTextMLP(config, intermediate_size) for _ in range(self.num_experts)]) + + if not unsupported_meta_device(original): + for i in range(self.num_experts): + gate_up = original.gate_up_proj[i] + down = original.down_proj[i] + + gate_proj = gate_up[:, :intermediate_size] + up_proj = gate_up[:, intermediate_size:] + + _update_parameter(self[i].gate_proj, "weight", 
diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py
index 1c335f230..adfc14f79 100644
--- a/auto_round/special_model_handler.py
+++ b/auto_round/special_model_handler.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import torch
+
 import auto_round.modelling as auto_round_modelling
 from auto_round.utils import LazyImport, logger, unsupported_meta_device
 
@@ -28,6 +30,7 @@
     "llama4",
     "internvl_chat",
     "glm4v_moe",
+    "qwen3_vl_moe",
 ]
 
 NOT_SUPPORT_ONLY_TEXT_MODELS = ["mllama", "mistral3_2"]
@@ -37,7 +40,7 @@
 }
 SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",)
 
-CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"]
+CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss", "qwen3_vl_moe"]
 
 MISTRAL_3_2_MODELS = ["Mistral-Small-3.2", "Magistral-Small", "Devstral-Small"]
 
@@ -47,6 +50,7 @@ def _get_moe_converter(config):
     moe_converters = {
         "gpt_oss": LazyImport("auto_round.modelling.gpt_oss.get_replacement_info"),
         "llama4": LazyImport("auto_round.modelling.llama4.get_replacement_info"),
+        "qwen3_vl_moe": LazyImport("auto_round.modelling.qwen3_vl_moe.get_replacement_info"),
     }
 
     # Retrieve the appropriate function based on model_type
diff --git a/test/test_cpu/test_moe_model.py b/test/test_cpu/test_moe_model.py
index c88571346..24bd39004 100644
--- a/test/test_cpu/test_moe_model.py
+++ b/test/test_cpu/test_moe_model.py
@@ -1,8 +1,9 @@
 import shutil
 
 import pytest
-from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Llama4ForConditionalGeneration
 from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
+from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration
 
 from auto_round import AutoRound
 
@@ -32,6 +33,21 @@ def setup_llama4():
     return model, tokenizer, output_dir, config
 
 
+@pytest.fixture
+def setup_qwen3_vl_moe():
+    """Fixture to set up the qwen3_vl_moe model, tokenizer, and processor."""
+    model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-VL-30B-A3B-Instruct"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    config = AutoConfig.from_pretrained(model_name)
+    config.vision_config.num_hidden_layers = 1
+    config.text_config.num_hidden_layers = 1
+    config.num_hidden_layers = 1  # Reduce layers for testing
+    processor = AutoProcessor.from_pretrained(model_name)
+    model = Qwen3VLMoeForConditionalGeneration(config)
+    output_dir = "/tmp/test_quantized_qwen3_vl_moe"
+    return model, tokenizer, processor, output_dir, config
+
+
 def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
     """Helper function to quantize the model with the given scheme."""
     autoround = AutoRound(
@@ -76,7 +92,7 @@ def test_llama4(setup_llama4):
     delattr(model.config.text_config, "moe_layers")
     delattr(model.config.text_config, "layer_types")
 
-    quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")
+    quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4", iters=1)
 
     # Ensure the quantized model is not None
     assert quantized_model is not None, "Quantized model should not be None."
@@ -88,3 +104,32 @@ def test_llama4(setup_llama4):
             assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
     # clean the output directory after test
     shutil.rmtree(output_dir, ignore_errors=True)
+
+
+def test_qwen3_vl_moe_mxfp(setup_qwen3_vl_moe):
+    model, tokenizer, processor, output_dir, config = setup_qwen3_vl_moe
+    autoround = AutoRound(
+        model,
+        tokenizer=tokenizer,
+        processor=processor,
+        scheme="MXFP4",
+        nsamples=2,
+        seqlen=32,
+        iters=1,
+        fp_layers="self_attn,lm_head,mlp.gate",
+    )
+    quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
+    assert quantized_model is not None, "Quantized model should not be None."
+    loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(output_dir, device_map="cpu")
+
+    for n, m in quantized_model.named_modules():
+        if m.__class__.__name__ == "QuantLinear":
+            loaded_m = loaded_model.get_submodule(n)
+            assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
+    # test generation
+    tokenizer = AutoTokenizer.from_pretrained(output_dir)
+    text = "There is a girl who likes adventure,"
+    inputs = tokenizer(text, return_tensors="pt").to(device=loaded_model.device)
+    print(tokenizer.decode(loaded_model.generate(**inputs, max_new_tokens=50)[0]))
+    # clean the output directory after test
+    shutil.rmtree(output_dir, ignore_errors=True)
"Quantized model should not be None." @@ -88,3 +104,32 @@ def test_llama4(setup_llama4): assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all() # clean the output directory after test shutil.rmtree(output_dir, ignore_errors=True) + + +def test_qwen3_vl_moe_mxfp(setup_qwen3_vl_moe): + model, tokenizer, processor, output_dir, config = setup_qwen3_vl_moe + autoround = AutoRound( + model, + tokenizer=tokenizer, + processor=processor, + scheme="MXFP4", + nsamples=2, + seqlen=32, + iters=1, + fp_layers="self_attn,lm_head,mlp.gate", + ) + quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + assert quantized_model is not None, "Quantized model should not be None." + loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(output_dir, device_map="cpu") + + for n, m in quantized_model.named_modules(): + if m.__class__.__name__ == "QuantLinear": + loaded_m = loaded_model.get_submodule(n) + assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all() + # test generation + tokenizer = AutoTokenizer.from_pretrained(output_dir) + text = "There is a girl who likes adventure," + inputs = tokenizer(text, return_tensors="pt").to(device=loaded_model.device) + print(tokenizer.decode(loaded_model.generate(**inputs, max_new_tokens=50)[0])) + # clean the output directory after test + shutil.rmtree(output_dir, ignore_errors=True) diff --git a/test/test_cuda/test_moe_model.py b/test/test_cuda/test_moe_model.py index d3a277de5..1865519c9 100644 --- a/test/test_cuda/test_moe_model.py +++ b/test/test_cuda/test_moe_model.py @@ -2,8 +2,9 @@ import pytest import torch -from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration +from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Llama4ForConditionalGeneration from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM +from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration from auto_round import AutoRound @@ -33,6 +34,21 @@ def setup_llama4(): return model, tokenizer, output_dir, config +@pytest.fixture +def setup_qwen3_vl_moe(): + """Fixture to set up the qwen3_vl_moe model and tokenizer.""" + model_name = "/models/Qwen3-VL-30B-A3B-Instruct" + tokenizer = AutoTokenizer.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name) + config.vision_config.num_hidden_layers = 1 + config.text_config.num_hidden_layers = 1 + config.num_hidden_layers = 1 # Reduce layers for testing + processor = AutoProcessor.from_pretrained(model_name) + model = Qwen3VLMoeForConditionalGeneration(config) + output_dir = "/tmp/test_quantized_qwen3_vl_moe" + return model, tokenizer, processor, output_dir, config + + def quantize_model(model, tokenizer, output_dir, scheme, iters=0): """Helper function to quantize the model with the given scheme.""" autoround = AutoRound( @@ -103,3 +119,33 @@ def test_llama4(setup_llama4): # clean the output directory after test shutil.rmtree(output_dir, ignore_errors=True) + + +def test_qwen3_vl_moe_mxfp(setup_qwen3_vl_moe): + model, tokenizer, processor, output_dir, config = setup_qwen3_vl_moe + autoround = AutoRound( + model, + tokenizer=tokenizer, + processor=processor, + scheme="MXFP4", + nsamples=2, + seqlen=32, + iters=1, + fp_layers="self_attn,lm_head,mlp.gate", + ) + quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + assert quantized_model is not None, "Quantized model should not be None." 
+    loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(output_dir)
+    loaded_model.to("cuda")
+    quantized_model.to("cuda")
+    for n, m in quantized_model.named_modules():
+        if m.__class__.__name__ == "QuantLinear":
+            loaded_m = loaded_model.get_submodule(n)
+            assert (loaded_m.weight_packed == m.weight_packed).all()
+    # test generation
+    tokenizer = AutoTokenizer.from_pretrained(output_dir)
+    text = "There is a girl who likes adventure,"
+    inputs = tokenizer(text, return_tensors="pt").to(device=loaded_model.device)
+    print(tokenizer.decode(loaded_model.generate(**inputs, max_new_tokens=50)[0]))
+    # clean the output directory after test
+    shutil.rmtree(output_dir, ignore_errors=True)
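Taken together, these changes register qwen3_vl_moe as a convert-expert-to-linear model so AutoRound can swap in the calibration MoE block before quantization. A minimal usage sketch, mirroring the arguments used in test_qwen3_vl_moe_mxfp above (the checkpoint id and output directory are placeholders, and the tiny nsamples/seqlen/iters values are only sensible for smoke testing):

from transformers import AutoProcessor, AutoTokenizer
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration

from auto_round import AutoRound

model_name = "Qwen/Qwen3-VL-30B-A3B-Instruct"  # placeholder checkpoint id
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer=tokenizer,
    processor=processor,
    scheme="MXFP4",
    nsamples=2,
    seqlen=32,
    iters=1,
    fp_layers="self_attn,lm_head,mlp.gate",  # keep attention, lm_head, and router gates unquantized
)
autoround.quantize_and_save(format="auto_round", output_dir="./qwen3-vl-moe-mxfp4")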