Skip to content
15 changes: 13 additions & 2 deletions auto_round/experimental/qmodules/fp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,18 @@

import torch

kE2M1ToFloat = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32)
_DEVICE_E2M1_TENSORS = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are default values ​​not needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor-type global variables may lead to device-related issues in future Transformers releases (e.g., v5.0.0); therefore, they are initialized lazily upon first use.


# Constants for FP4 values (E2M1 format)
_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def get_e2m1_tensor(device):
"""Get device-specific E2M1 lookup tensor, creating it if needed."""
device_str = str(device)
if device_str not in _DEVICE_E2M1_TENSORS:
_DEVICE_E2M1_TENSORS[device_str] = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=device)
return _DEVICE_E2M1_TENSORS[device_str]


def unpack_fp4_from_uint8(
Expand Down Expand Up @@ -91,7 +102,7 @@ def _unpack_fp4_from_uint8(
abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices

# Device-aware lookup and sign application
kE2M1 = kE2M1ToFloat.to(device=a.device)
kE2M1 = get_e2m1_tensor(device=a.device)
values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)

# Reshape to final form
Expand Down
146 changes: 146 additions & 0 deletions auto_round/modelling/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import transformers
from packaging import version
from torch import nn
from transformers.activations import ACT2FN

from auto_round.utils import logger, unsupported_meta_device

transformers_version = version.parse(transformers.__version__)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from transformers import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextSparseMoeBlock,
)


def _update_parameter(
module: torch.nn.Module,
name: str,
data: torch.Tensor,
) -> None:
param = getattr(module, name)
param.data.copy_(data)


# Adapted from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/qwen3_vl_moe.py
class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
"""
Calibration version of Qwen3VLMoeTextSparseMoeBlock that sends all tokens to all
experts.
"""

is_permanent = True

def __init__(
self,
original: "Qwen3VLMoeTextSparseMoeBlock",
config: "Qwen3VLMoeConfig",
calibrate_all_experts: bool = False,
):
super().__init__()
text_config: "Qwen3VLMoeTextConfig" = config.get_text_config()

self.hidden_size = text_config.hidden_size
self.num_experts = text_config.num_experts
self.top_k = original.top_k
# Note: gate was changed to be a Linear layer in transformers==4.57.0
# https://github.com/JJJYmmm/transformers/commit/f5dea1c694af8c994c769170813a8702332119ee
self.gate = original.gate
self.calibrate_all_experts = calibrate_all_experts
self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)
if not transformers_version <= version.parse(
"4.57.3"
): # remove conversion_mapping for qwen3_vl_moe when transformers>=5.0
from transformers.conversion_mapping import register_checkpoint_conversion_mapping

register_checkpoint_conversion_mapping(config.model_type, [], overwrite=True)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.reshape(-1, hidden_dim)

# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
# get topk experts per token
# routing_weight: (num_tokens, top_k)
# routing_indices: (num_tokens, top_k)
routing_weights, router_indices = torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

next_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)

# convert router indices into OHE list
# reshape to be (num_experts, top_k, batch_size * sequence_length)
expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)

for expert_idx, expert_layer in enumerate(self.experts):
idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))

if self.calibrate_all_experts:
expert_out = expert_layer(hidden_states)[token_idx]
else:
expert_out = expert_layer(hidden_states[token_idx])

if len(token_idx) > 0:
# if there are tokens meant for this expert, further scale the expert
# output by the score
weighted_output = expert_out * routing_weights[token_idx, idx, None]
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
next_states = next_states.reshape(batch_size, sequence_length, hidden_dim)

if transformers_version <= version.parse("4.57.3"):
return next_states, router_logits
else:
return next_states


class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
super().__init__()
self.num_experts = original.gate_up_proj.shape[0]
intermediate_size = config.moe_intermediate_size
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextMLP,
)

super().__init__([Qwen3VLMoeTextMLP(config, intermediate_size) for _ in range(self.num_experts)])

if not unsupported_meta_device(original):
for i in range(self.num_experts):
gate_up = original.gate_up_proj[i]
down = original.down_proj[i]

gate_proj = gate_up[:, :intermediate_size]
up_proj = gate_up[:, intermediate_size:]

_update_parameter(self[i].gate_proj, "weight", gate_proj.t().contiguous())
_update_parameter(self[i].up_proj, "weight", up_proj.t().contiguous())
_update_parameter(self[i].down_proj, "weight", down.t().contiguous())


def get_replacement_info(config):
return (LinearQwen3VLMoeTextSparseMoeBlock, config, "Qwen3VLMoeTextSparseMoeBlock")
6 changes: 5 additions & 1 deletion auto_round/special_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import auto_round.modelling as auto_round_modelling
from auto_round.utils import LazyImport, logger, unsupported_meta_device

Expand All @@ -28,6 +30,7 @@
"llama4",
"internvl_chat",
"glm4v_moe",
"qwen3_vl_moe",
]

NOT_SUPPORT_ONLY_TEXT_MODELS = ["mllama", "mistral3_2"]
Expand All @@ -37,7 +40,7 @@
}
SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",)

CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"]
CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss", "qwen3_vl_moe"]

MISTRAL_3_2_MODELS = ["Mistral-Small-3.2", "Magistral-Small", "Devstral-Small"]

Expand All @@ -47,6 +50,7 @@ def _get_moe_converter(config):
moe_converters = {
"gpt_oss": LazyImport("auto_round.modelling.gpt_oss.get_replacement_info"),
"llama4": LazyImport("auto_round.modelling.llama4.get_replacement_info"),
"qwen3_vl_moe": LazyImport("auto_round.modelling.qwen3_vl_moe.get_replacement_info"),
}

# Retrieve the appropriate function based on model_type
Expand Down
49 changes: 47 additions & 2 deletions test/test_cpu/test_moe_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import shutil

import pytest
from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Llama4ForConditionalGeneration
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration

from auto_round import AutoRound

Expand Down Expand Up @@ -32,6 +33,21 @@ def setup_llama4():
return model, tokenizer, output_dir, config


@pytest.fixture
def setup_qwen3_vl_moe():
"""Fixture to set up the qwen3_vl_moe model and tokenizer."""
model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-VL-30B-A3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
config.vision_config.num_hidden_layers = 1
config.text_config.num_hidden_layers = 1
config.num_hidden_layers = 1 # Reduce layers for testing
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen3VLMoeForConditionalGeneration(config)
output_dir = "/tmp/test_quantized_qwen3_vl_moe"
return model, tokenizer, processor, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
"""Helper function to quantize the model with the given scheme."""
autoround = AutoRound(
Expand Down Expand Up @@ -76,7 +92,7 @@ def test_llama4(setup_llama4):
delattr(model.config.text_config, "moe_layers")
delattr(model.config.text_config, "layer_types")

quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")
quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4", iters=1)

# Ensure the quantized model is not None
assert quantized_model is not None, "Quantized model should not be None."
Expand All @@ -88,3 +104,32 @@ def test_llama4(setup_llama4):
assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)


def test_qwen3_vl_moe_mxfp(setup_qwen3_vl_moe):
model, tokenizer, processor, output_dir, config = setup_qwen3_vl_moe
autoround = AutoRound(
model,
tokenizer=tokenizer,
processor=processor,
scheme="MXFP4",
nsamples=2,
seqlen=32,
iters=1,
fp_layers="self_attn,lm_head,mlp.gate",
)
quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
assert quantized_model is not None, "Quantized model should not be None."
loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(output_dir, device_map="cpu")

for n, m in quantized_model.named_modules():
if m.__class__.__name__ == "QuantLinear":
loaded_m = loaded_model.get_submodule(n)
assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
# test generation
tokenizer = AutoTokenizer.from_pretrained(output_dir)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(device=loaded_model.device)
print(tokenizer.decode(loaded_model.generate(**inputs, max_new_tokens=50)[0]))
# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)
48 changes: 47 additions & 1 deletion test/test_cuda/test_moe_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import pytest
import torch
from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
from transformers import AutoConfig, AutoProcessor, AutoTokenizer, Llama4ForConditionalGeneration
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeForConditionalGeneration

from auto_round import AutoRound

Expand Down Expand Up @@ -33,6 +34,21 @@ def setup_llama4():
return model, tokenizer, output_dir, config


@pytest.fixture
def setup_qwen3_vl_moe():
"""Fixture to set up the qwen3_vl_moe model and tokenizer."""
model_name = "/models/Qwen3-VL-30B-A3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)
config.vision_config.num_hidden_layers = 1
config.text_config.num_hidden_layers = 1
config.num_hidden_layers = 1 # Reduce layers for testing
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen3VLMoeForConditionalGeneration(config)
output_dir = "/tmp/test_quantized_qwen3_vl_moe"
return model, tokenizer, processor, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
"""Helper function to quantize the model with the given scheme."""
autoround = AutoRound(
Expand Down Expand Up @@ -103,3 +119,33 @@ def test_llama4(setup_llama4):

# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)


def test_qwen3_vl_moe_mxfp(setup_qwen3_vl_moe):
model, tokenizer, processor, output_dir, config = setup_qwen3_vl_moe
autoround = AutoRound(
model,
tokenizer=tokenizer,
processor=processor,
scheme="MXFP4",
nsamples=2,
seqlen=32,
iters=1,
fp_layers="self_attn,lm_head,mlp.gate",
)
quantized_model, _ = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
assert quantized_model is not None, "Quantized model should not be None."
loaded_model = Qwen3VLMoeForConditionalGeneration.from_pretrained(output_dir)
loaded_model.to("cuda")
quantized_model.to("cuda")
for n, m in quantized_model.named_modules():
if m.__class__.__name__ == "QuantLinear":
loaded_m = loaded_model.get_submodule(n)
assert (loaded_m.weight_packed == m.weight_packed).all()
# test generation
tokenizer = AutoTokenizer.from_pretrained(output_dir)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(device=loaded_model.device)
print(tokenizer.decode(loaded_model.generate(**inputs, max_new_tokens=50)[0]))
# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)