Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions QEfficient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from QEfficient.compile.compile_helper import compile
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
from QEfficient.diffusers.pipelines.qwen_image.pipeline_qwenimage import QEFFQwenImagePipeline
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.peft import QEffAutoPeftModelForCausalLM
Expand All @@ -55,6 +56,7 @@
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
"QEffFluxPipeline",
"QEFFQwenImagePipeline",
]


Expand Down
13 changes: 13 additions & 0 deletions QEfficient/diffusers/models/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# -----------------------------------------------------------------------------

from diffusers.models.attention_processor import Attention
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, RMSNorm
from diffusers.models.transformers.transformer_flux import (
FluxAttention,
Expand All @@ -13,6 +14,10 @@
FluxTransformer2DModel,
FluxTransformerBlock,
)
from diffusers.models.transformers.transformer_qwenimage import (
QwenDoubleStreamAttnProcessor2_0,
QwenImageTransformer2DModel,
)
from torch import nn

from QEfficient.base.pytorch_transforms import ModuleMappingTransform
Expand All @@ -29,6 +34,11 @@
QEffFluxTransformer2DModel,
QEffFluxTransformerBlock,
)
from QEfficient.diffusers.models.transformers.transformer_qwenimage import (
QEffQwenDoubleStreamAttnProcessor2_0,
QEffQwenImageAttention,
QEffQwenImageTransformer2DModel,
)


class CustomOpsTransform(ModuleMappingTransform):
Expand All @@ -45,6 +55,9 @@ class AttentionTransform(ModuleMappingTransform):
FluxTransformer2DModel: QEffFluxTransformer2DModel,
FluxAttention: QEffFluxAttention,
FluxAttnProcessor: QEffFluxAttnProcessor,
QwenImageTransformer2DModel: QEffQwenImageTransformer2DModel,
QwenDoubleStreamAttnProcessor2_0: QEffQwenDoubleStreamAttnProcessor2_0,
Attention: QEffQwenImageAttention,
}


Expand Down
409 changes: 409 additions & 0 deletions QEfficient/diffusers/models/transformers/transformer_qwenimage.py

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions QEfficient/diffusers/pipelines/pipeline_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,116 @@ def compile(self, specializations: List[Dict], **compiler_options) -> None:
**compiler_options: Additional compiler options (e.g., num_cores, aic_num_of_activations)
"""
self._compile(specializations=specializations, **compiler_options)


class QEffQwenImageTransformer2DModel(QEFFBaseModel):
_pytorch_transforms = [AttentionTransform, CustomOpsTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

"""
QEffQwenImageTransformer2DModel is a wrapper class for QwenImage Transformer2D models that provides ONNX export and compilation capabilities.

This class extends QEFFBaseModel to handle QwenImage Transformer2D models with specific transformations and optimizations
for efficient inference on Qualcomm AI hardware. It is designed for the QwenImage architecture that uses
transformer-based diffusion models with unique latent packing and attention mechanisms.
"""

def __init__(self, model: nn.modules):
super().__init__(model)
self.model = model

def get_onnx_config(self):
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE

# For testing purpose I have set this to constant values from the original models
latent_seq_len = 6032
text_seq_len = 126
hidden_dim = 64
encoder_hidden_dim = 3584
example_inputs = {
"hidden_states": torch.randn(bs, latent_seq_len, hidden_dim, dtype=torch.float32),
"encoder_hidden_states": torch.randn(bs, text_seq_len, encoder_hidden_dim, dtype=torch.float32),
"encoder_hidden_states_mask": torch.ones(bs, text_seq_len, dtype=torch.int64),
"timestep": torch.tensor([1000.0], dtype=torch.float32),
"frame": torch.tensor([1], dtype=torch.int64),
"height": torch.tensor([58], dtype=torch.int64),
"width": torch.tensor([104], dtype=torch.int64),
"txt_seq_lens": torch.tensor([126], dtype=torch.int64),
}

output_names = ["output"]

dynamic_axes = {
"hidden_states": {0: "batch_size", 1: "latent_seq_len"},
"encoder_hidden_states": {0: "batch_size", 1: "text_seq_len"},
"encoder_hidden_states_mask": {0: "batch_size", 1: "text_seq_len"},
}

return example_inputs, dynamic_axes, output_names

def export(
self,
inputs,
output_names,
dynamic_axes,
export_dir=None,
export_kwargs=None,
):
return self._export(
example_inputs=inputs,
output_names=output_names,
dynamic_axes=dynamic_axes,
export_dir=export_dir,
export_kwargs=export_kwargs,
)

def get_specializations(
self,
batch_size: int,
latent_seq_len: int,
text_seq_len: int,
):
specializations = [
{
"batch_size": batch_size,
"latent_seq_len": latent_seq_len,
"text_seq_len": text_seq_len,
}
]

return specializations

def compile(
self,
compile_dir,
compile_only,
specializations,
convert_to_fp16,
mxfp6_matmul,
mdp_ts_num_devices,
aic_num_cores,
custom_io,
**compiler_options,
) -> str:
return self._compile(
compile_dir=compile_dir,
compile_only=compile_only,
specializations=specializations,
convert_to_fp16=convert_to_fp16,
mxfp6_matmul=mxfp6_matmul,
mdp_ts_num_devices=mdp_ts_num_devices,
aic_num_cores=aic_num_cores,
custom_io=custom_io,
**compiler_options,
)

@property
def model_name(self) -> str:
mname = self.model.__class__.__name__
if mname.startswith("QEff") or mname.startswith("QEFF"):
mname = mname[4:]
return mname

@property
def get_model_config(self) -> dict:
return self.model.config.__dict__
6 changes: 6 additions & 0 deletions QEfficient/diffusers/pipelines/qwen_image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------
Loading
Loading