18 changes: 8 additions & 10 deletions fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -25,7 +25,7 @@
fp32_to_hfp8_with_clamp,
fp32_to_mx4,
hfp8_to_fp32,
mx4_to_fp32,
mx4_to_float,
RoundingMode,
)

@@ -123,7 +123,7 @@ def _dequantize_tensor(
comm_precision: SparseType,
ctx: Optional[QuantizationContext] = None,
is_fwd: bool = True,
fp8_output_dtype: Optional[SparseType] = None,
output_dtype: Optional[SparseType] = None,
) -> torch.Tensor:
if comm_precision == SparseType.FP32:
assert quantized_tensor.dtype == torch.float
@@ -138,10 +138,8 @@ def _dequantize_tensor(
if ctx is not None and ctx.row_dim > 0:
row_dim_quant = ctx.row_dim_quant
quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
# use provided fp8_output_dtype or default to FP32 (0)
output_dtype_int = (
fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
)
# use provided output_dtype or default to FP32 (0)
output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
quantized_tensor_2d,
is_fwd,
@@ -161,7 +159,7 @@ def _dequantize_tensor(
return dequant_tensor.view(-1)
elif comm_precision == SparseType.MX4:
mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
return mx4_to_fp32(quantized_tensor, mx_group_size)
return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
else:
raise ValueError(f"comm_precision={comm_precision} is not supported")

@@ -175,7 +173,7 @@ def __init__(
row_dim: Optional[int] = None,
is_fwd: bool = True,
rounding_mode: Optional[RoundingMode] = None,
fp8_output_dtype: Optional[SparseType] = None,
output_dtype: Optional[SparseType] = None,
) -> None:
if loss_scale is not None:
if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -193,7 +191,7 @@ def __init__(
self._is_fwd = is_fwd
self._row_dim: int = -1 if row_dim is None else row_dim
self._rounding_mode: Optional[RoundingMode] = rounding_mode
self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
self._output_dtype: Optional[SparseType] = output_dtype
if self._comm_precision == SparseType.MX4:
self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
self._rounding_mode = (
@@ -229,7 +227,7 @@ def decode(
self._comm_precision,
ctx,
self._is_fwd,
fp8_output_dtype=self._fp8_output_dtype,
output_dtype=self._output_dtype,
)
return dequantized_tensor

60 changes: 54 additions & 6 deletions fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -14,7 +14,9 @@

import fbgemm_gpu

from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.triton import quantize_mx4, RoundingMode
from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

try:
@@ -126,25 +128,71 @@ def mx4_to_fp32(
) -> torch.Tensor:
"""Dequantize an MX4 tensor to FP32 with triton or native cuda impl.

This function is kept for backward compatibility and always returns FP32.
For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
"""
return mx4_to_float(
tensor,
group_size,
use_triton,
ebits,
mbits,
output_dtype=None, # None = FP32 default for backward compatibility
)


def mx4_to_float(
tensor: torch.Tensor,
group_size: int = 32,
use_triton: bool = True,
ebits: int = 2,
mbits: int = 1,
output_dtype: Optional[SparseType] = None,
) -> torch.Tensor:
"""Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.

Args:
tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
group_size (int): Compute scale in chunks of group_size.
use_triton (bool): If set, use triton quantization, otherwise cuda.
ebits (int): Number of exponent bits in target mx4 format.
mbits (int): Number of mantissa bits in target mx4 format.
output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
Defaults to None (FP32) for backward compatibility.

Return:
output: FP32 tensor with total elements (M).
output: Tensor with dtype matching output_dtype and total elements (M).
"""
# Validate output_dtype
supported_dtypes = {SparseType.FP32, SparseType.BF16}
if output_dtype is not None and output_dtype not in supported_dtypes:
raise ValueError(
f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
f"Use BF16 for memory savings with same dynamic range as FP32."
)

target_dtype = (
output_dtype.as_dtype() if output_dtype is not None else torch.float32
)

# Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
if not tensor.is_cuda and not tensor.is_mtia:
return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
return result.to(target_dtype) if output_dtype is not None else result
if use_triton:
if tensor.is_mtia:
return mtia_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
return mtia_dequantize_mx4(
tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
)
return triton_dequantize_mx4(
tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
)
else:
return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
return torch.ops.fbgemm.dequantize_mx_cuda(
tensor.flatten(), group_size, output_dtype_int
)


def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
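Usage sketch for the new `mx4_to_float` entry point above (illustrative only; device, shapes, and group size are assumptions based on the signatures shown in this diff):

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float, mx4_to_fp32
from fbgemm_gpu.split_embedding_configs import SparseType

x = torch.randn(128, 1024, device="cuda", dtype=torch.float32)
packed = fp32_to_mx4(x, group_size=32)  # MX4 payload packed into int8

y_fp32 = mx4_to_fp32(packed, group_size=32)  # legacy path, always FP32
y_bf16 = mx4_to_float(packed, group_size=32, output_dtype=SparseType.BF16)
assert y_bf16.dtype == torch.bfloat16
# The CUDA path returns a flattened tensor; reshape to x.shape if needed,
# as the new test below does.

# SparseType.FP16 is rejected with a ValueError: MX4's shared exponents can
# exceed FP16's dynamic range, so only FP32 and BF16 outputs are allowed.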
20 changes: 13 additions & 7 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -575,7 +575,7 @@ def _kernel_dequantize_mx4(
# Write final outputs.
tl.store(
out + output_offset,
scaled_fp32,
scaled_fp32.to(out.dtype.element_ty),
# Mask values that are out of this chunk or the main array.
mask=(output_offset < OUTPUT_SIZE)
& (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
@@ -588,24 +588,30 @@ def _kernel_dequantize_mx4(


def triton_dequantize_mx4(
a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
a: torch.Tensor,
group_size: int = 32,
ebits: int = 2,
mbits: int = 1,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Dequantize a tensor from mx4 format to fp32.
Dequantize a tensor from mx4 format to fp32 or bf16.

Args:
a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
with group exponents attached to end of each row.
group_size (int): Size of chunks that use the same shared exponent.
ebits (int): Number of bits to use for exponent in target mx4 format.
mbits (int): Number of bits to use for mantissa in target mx4 format.
output_dtype (torch.dtype): Output dtype (FP32 or BF16).
Defaults to torch.float32 for backward compatibility.

Returns:
torch.Tensor: [M, K] dequantized fp32 tensor.
torch.Tensor: [M, K] dequantized tensor in the specified dtype.
"""
# If given an empty shape, return an empty tensor.
if a.numel() == 0:
return torch.empty(a.shape, device=a.device, dtype=torch.float32)
return torch.empty(a.shape, device=a.device, dtype=output_dtype)
# View a as 2D for simplicity.
orig_shape = a.shape
a = a.flatten()
@@ -622,9 +628,9 @@ def triton_dequantize_mx4(
# Use a lookup table to convert
mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)

# Create output tensor.
# Create output tensor in target dtype.
output_elems = num_groups * group_size
out = torch.empty([output_elems], device=a.device, dtype=torch.float)
out = torch.empty([output_elems], device=a.device, dtype=output_dtype)
# Check if we need to use int64 for indexing.
use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
# Invoke triton dequantization kernel over rows.
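At the Triton layer the new parameter is a plain `torch.dtype` rather than a `SparseType`; since the kernel now casts each store to `out.dtype.element_ty`, allocating the output buffer in the requested dtype is all that is needed. A minimal sketch (CUDA device placement, shapes, and group size are illustrative):

import torch

from fbgemm_gpu.triton import quantize_mx4
from fbgemm_gpu.triton.quantize import triton_dequantize_mx4

x = torch.randn(64, 256, device="cuda", dtype=torch.float32)
packed = quantize_mx4(x, group_size=32)

# Request BF16 output directly from the Triton dequantize kernel.
out_bf16 = triton_dequantize_mx4(packed, group_size=32, output_dtype=torch.bfloat16)
assert out_bf16.dtype == torch.bfloat16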
70 changes: 69 additions & 1 deletion fbgemm_gpu/test/quantize/mx4_test.py
@@ -13,7 +13,13 @@

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
from fbgemm_gpu.quantize_utils import (
fp32_to_mx4,
mx4_to_float,
mx4_to_fp32,
RoundingMode,
)
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4

from hypothesis import given, settings, Verbosity
@@ -287,6 +293,68 @@ def test_mx4_cases(
# I give quite a bit of wiggle room to make sure this isn't flaky.
torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)

@unittest.skipIf(*gpu_unavailable)
# pyre-fixme[56]:
@given(
output_dtype=st.sampled_from([None, SparseType.FP32, SparseType.BF16]),
use_triton=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
def test_mx4_to_float_correctness(
self, output_dtype: SparseType | None, use_triton: bool
) -> None:
"""Test MX4 dequantization correctness for FP32 and BF16 output dtypes.

Validates that mx4_to_float produces correct outputs:
- dtype matches expected (None -> FP32, SparseType.FP32 -> FP32, SparseType.BF16 -> BF16)
- FP32 output matches against cpu reference
- BF16 output matches against FP32 -> BF16 conversion
"""
device = torch.device(torch.accelerator.current_accelerator() or "cuda")
input_data = torch.randn([128, 1024], device=device, dtype=torch.float32)

quantized = fp32_to_mx4(
input_data,
group_size=32,
use_triton=use_triton,
rounding_mode=RoundingMode.floor,
)
output = mx4_to_float(
quantized,
group_size=32,
use_triton=use_triton,
output_dtype=output_dtype,
)
output = output.reshape(input_data.shape) # CUDA returns flattened tensor

# validate dtype
expected_dtype = output_dtype.as_dtype() if output_dtype else torch.float32
self.assertEqual(output.dtype, expected_dtype)

# validate fp32 against cpu reference
if expected_dtype == torch.float32:
input_cpu = input_data.cpu()
quantized_cpu = py_quantize_mx4(
input_cpu, group_size=32, rounding_mode=RoundingMode.floor
)
output_cpu = py_dequantize_mx4(quantized_cpu, group_size=32)
assert check_diff_quantize(input_cpu, output_cpu, output.cpu())

# validate bf16 matches fp32->bf16 conversion
elif expected_dtype == torch.bfloat16:
output_fp32 = mx4_to_float(
quantized,
group_size=32,
use_triton=use_triton,
output_dtype=None, # Get FP32 output
)
output_fp32_as_bf16 = output_fp32.reshape(input_data.shape).to(
torch.bfloat16
)
torch.testing.assert_close(
output.cpu(), output_fp32_as_bf16.cpu(), rtol=0.0, atol=0.0
)

# pyre-fixme[56]:
@unittest.skipIf(
not (