From 5055f5ea4dab0563caf55872999e25e9798966ee Mon Sep 17 00:00:00 2001
From: Armand Sauzay
Date: Wed, 17 Dec 2025 10:09:14 -0800
Subject: [PATCH] =?UTF-8?q?Enable=20direct=20MX4=E2=86=92BF16=20dequantiza?=
 =?UTF-8?q?tion=20to=20reduce=20memory=20(python=20side)=20(2/2)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
X-link: https://github.com/meta-pytorch/torchrec/pull/3620
X-link: https://github.com/facebookresearch/FBGEMM/pull/2241

Python side of the changes enabling direct MX4→BF16 dequantization to reduce memory.

Differential Revision: D88913122
---
 fbgemm_gpu/fbgemm_gpu/quantize_comm.py   | 18 +++---
 fbgemm_gpu/fbgemm_gpu/quantize_utils.py  | 60 ++++++++++++++++++--
 fbgemm_gpu/fbgemm_gpu/triton/quantize.py | 20 ++++---
 fbgemm_gpu/test/quantize/mx4_test.py     | 70 +++++++++++++++++++++++-
 4 files changed, 144 insertions(+), 24 deletions(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
index 806dcb4959..25f6e67c3e 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -25,7 +25,7 @@
     fp32_to_hfp8_with_clamp,
     fp32_to_mx4,
     hfp8_to_fp32,
-    mx4_to_fp32,
+    mx4_to_float,
     RoundingMode,
 )
 
@@ -123,7 +123,7 @@ def _dequantize_tensor(
     comm_precision: SparseType,
     ctx: Optional[QuantizationContext] = None,
     is_fwd: bool = True,
-    fp8_output_dtype: Optional[SparseType] = None,
+    output_dtype: Optional[SparseType] = None,
 ) -> torch.Tensor:
     if comm_precision == SparseType.FP32:
         assert quantized_tensor.dtype == torch.float
@@ -138,10 +138,8 @@ def _dequantize_tensor(
         if ctx is not None and ctx.row_dim > 0:
             row_dim_quant = ctx.row_dim_quant
             quantized_tensor_2d = quantized_tensor.view((-1, row_dim_quant))
-            # use provided fp8_output_dtype or default to FP32 (0)
-            output_dtype_int = (
-                fp8_output_dtype.as_int() if fp8_output_dtype is not None else 0
-            )
+            # use provided output_dtype or default to FP32 (0)
+            output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
             dequant_tensor = torch.ops.fbgemm.FP8RowwiseQuantizedToFloat(
                 quantized_tensor_2d,
                 is_fwd,
@@ -161,7 +159,7 @@ def _dequantize_tensor(
             return dequant_tensor.view(-1)
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        return mx4_to_fp32(quantized_tensor, mx_group_size)
+        return mx4_to_float(quantized_tensor, mx_group_size, output_dtype=output_dtype)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -175,7 +173,7 @@ def __init__(
         row_dim: Optional[int] = None,
         is_fwd: bool = True,
         rounding_mode: Optional[RoundingMode] = None,
-        fp8_output_dtype: Optional[SparseType] = None,
+        output_dtype: Optional[SparseType] = None,
     ) -> None:
         if loss_scale is not None:
             if comm_precision not in [SparseType.FP16, SparseType.BF16]:
@@ -193,7 +191,7 @@ def __init__(
         self._is_fwd = is_fwd
         self._row_dim: int = -1 if row_dim is None else row_dim
         self._rounding_mode: Optional[RoundingMode] = rounding_mode
-        self._fp8_output_dtype: Optional[SparseType] = fp8_output_dtype
+        self._output_dtype: Optional[SparseType] = output_dtype
         if self._comm_precision == SparseType.MX4:
             self._row_dim = MX_GROUP_SIZE_DEFAULT if row_dim is None else row_dim
             self._rounding_mode = (
@@ -229,7 +227,7 @@ def decode(
             self._comm_precision,
             ctx,
             self._is_fwd,
-            fp8_output_dtype=self._fp8_output_dtype,
+            output_dtype=self._output_dtype,
         )
         return dequantized_tensor
 
diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
index 0107b2a953..deff384be4 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_utils.py
@@ -14,7 +14,9 @@
 
 import fbgemm_gpu
 
-from fbgemm_gpu.triton import dequantize_mx4, quantize_mx4, RoundingMode
+from fbgemm_gpu.split_embedding_configs import SparseType
+from fbgemm_gpu.triton import quantize_mx4, RoundingMode
+from fbgemm_gpu.triton.quantize import triton_dequantize_mx4
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
 try:
@@ -126,25 +128,71 @@
 ) -> torch.Tensor:
     """Dequantize an MX4 tensor to FP32 with triton or native cuda impl.
 
+    This function is kept for backward compatibility and always returns FP32.
+    For BF16 output, use mx4_to_float() with output_dtype=SparseType.BF16.
+    """
+    return mx4_to_float(
+        tensor,
+        group_size,
+        use_triton,
+        ebits,
+        mbits,
+        output_dtype=None,  # None = FP32 default for backward compatibility
+    )
+
+
+def mx4_to_float(
+    tensor: torch.Tensor,
+    group_size: int = 32,
+    use_triton: bool = True,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: Optional[SparseType] = None,
+) -> torch.Tensor:
+    """Dequantize an MX4 tensor to FP32 or BF16 with triton or native cuda impl.
+
     Args:
         tensor (torch.Tensor): MX4 packed tensor with total elements (M / 2 + M / groupsize)
         group_size (int): Compute scale in chunks of group_size.
         use_triton (bool): If set, use triton quantization, otherwise cuda.
         ebits (int): Number of exponent bits in target mx4 format.
         mbits (int): Number of mantissa bits in target mx4 format.
+        output_dtype (Optional[SparseType]): Output dtype (FP32 or BF16).
+            Defaults to None (FP32) for backward compatibility.
 
     Return:
-        output: FP32 tensor with total elements (M).
+        output: Tensor with dtype matching output_dtype and total elements (M).
     """
+    # Validate output_dtype
+    supported_dtypes = {SparseType.FP32, SparseType.BF16}
+    if output_dtype is not None and output_dtype not in supported_dtypes:
+        raise ValueError(
+            f"output_dtype must be one of {supported_dtypes}, got {output_dtype}. "
+            f"FP16 is not supported due to potential overflow/underflow with MX4's wide exponent range. "
+            f"Use BF16 for memory savings with same dynamic range as FP32."
+        )
+
+    target_dtype = (
+        output_dtype.as_dtype() if output_dtype is not None else torch.float32
+    )
+
     # Accelerated MX4 dequantize is only available on cuda, if input is on cpu, use python.
     if not tensor.is_cuda and not tensor.is_mtia:
-        return py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        result = py_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+        return result.to(target_dtype) if output_dtype is not None else result
     if use_triton:
         if tensor.is_mtia:
-            return mtia_dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
-        return dequantize_mx4(tensor, group_size, ebits=ebits, mbits=mbits)
+            return mtia_dequantize_mx4(
+                tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+            )
+        return triton_dequantize_mx4(
+            tensor, group_size, ebits=ebits, mbits=mbits, output_dtype=target_dtype
+        )
     else:
-        return torch.ops.fbgemm.dequantize_mx_cuda(tensor.flatten(), group_size)
+        output_dtype_int = output_dtype.as_int() if output_dtype is not None else 0
+        return torch.ops.fbgemm.dequantize_mx_cuda(
+            tensor.flatten(), group_size, output_dtype_int
+        )
 
 
 def fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
index 5200a7cdc1..402d0342ab 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -575,7 +575,7 @@ def _kernel_dequantize_mx4(
     # Write final outputs.
     tl.store(
         out + output_offset,
-        scaled_fp32,
+        scaled_fp32.to(out.dtype.element_ty),
         # Mask values that are out of this chunk or the main array.
         mask=(output_offset < OUTPUT_SIZE)
         & (output_offset < OUTPUT_CHUNK_SIZE * (pid + 1)),
@@ -588,10 +588,14 @@
 
 
 def triton_dequantize_mx4(
-    a: torch.Tensor, group_size: int = 32, ebits: int = 2, mbits: int = 1
+    a: torch.Tensor,
+    group_size: int = 32,
+    ebits: int = 2,
+    mbits: int = 1,
+    output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """
-    Dequantize a tensor from mx4 format to fp32.
+    Dequantize a tensor from mx4 format to fp32 or bf16.
 
     Args:
         a (Tensor): [M / 2 + M / group_size] MX4 tensor packed into int8 values
@@ -599,13 +603,15 @@
         group_size (int): Size of chunks that use the same shared exponent.
         ebits (int): Number of bits to use for exponent in target mx4 format.
         mbits (int): Number of bits to use for mantissa in target mx4 format.
+        output_dtype (torch.dtype): Output dtype (FP32 or BF16).
+            Defaults to torch.float32 for backward compatibility.
 
     Returns:
-        torch.Tensor: [M, K] dequantized fp32 tensor.
+        torch.Tensor: [M, K] dequantized tensor in the specified dtype.
     """
     # If given an empty shape, return an empty tensor.
     if a.numel() == 0:
-        return torch.empty(a.shape, device=a.device, dtype=torch.float32)
+        return torch.empty(a.shape, device=a.device, dtype=output_dtype)
     # View a as 2D for simplicity.
     orig_shape = a.shape
     a = a.flatten()
@@ -622,9 +628,9 @@
 
     # Use a lookup table to convert
     mx4_to_fp_values = get_mx4_lookup_table(ebits, mbits, a.device)
-    # Create output tensor.
+    # Create output tensor in target dtype.
     output_elems = num_groups * group_size
-    out = torch.empty([output_elems], device=a.device, dtype=torch.float)
+    out = torch.empty([output_elems], device=a.device, dtype=output_dtype)
     # Check if we need to use int64 for indexing.
     use_int64 = num_threads * groups_per_thread * group_size > 2**31 - 1
     # Invoke triton dequantization kernel over rows.
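
The memory saving comes from the Triton path above: the output buffer is now allocated directly in the requested dtype and the kernel casts on store via scaled_fp32.to(out.dtype.element_ty), so no full-size FP32 intermediate is materialized for BF16 output. A minimal usage sketch of calling the wrapper directly, assuming a CUDA device and an illustrative tensor shape (neither is taken from this patch):

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4
from fbgemm_gpu.triton.quantize import triton_dequantize_mx4

# Illustrative FP32 input.
x = torch.randn(64, 512, device="cuda", dtype=torch.float32)

# Pack to MX4 with the Triton quantizer; each group of 32 values shares one exponent.
packed = fp32_to_mx4(x, group_size=32, use_triton=True)

# Dequantize straight into a BF16 buffer; the kernel writes bfloat16 directly.
y = triton_dequantize_mx4(packed, group_size=32, output_dtype=torch.bfloat16)
y = y.reshape(x.shape)  # reshape back, as the new test does for the accelerated path
assert y.dtype == torch.bfloat16
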
diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py
index fd2a0fb47b..39c7dddbbc 100644
--- a/fbgemm_gpu/test/quantize/mx4_test.py
+++ b/fbgemm_gpu/test/quantize/mx4_test.py
@@ -13,7 +13,13 @@
 
 import torch
 
-from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode
+from fbgemm_gpu.quantize_utils import (
+    fp32_to_mx4,
+    mx4_to_float,
+    mx4_to_fp32,
+    RoundingMode,
+)
+from fbgemm_gpu.split_embedding_configs import SparseType
 from fbgemm_gpu.triton.quantize_ref import py_dequantize_mx4, py_quantize_mx4
 
 from hypothesis import given, settings, Verbosity
@@ -287,6 +293,68 @@ def test_mx4_cases(
         # I give quite a bit of wiggle room to make sure this isnt flaky.
         torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)
 
+    @unittest.skipIf(*gpu_unavailable)
+    # pyre-fixme[56]:
+    @given(
+        output_dtype=st.sampled_from([None, SparseType.FP32, SparseType.BF16]),
+        use_triton=st.booleans(),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
+    def test_mx4_to_float_correctness(
+        self, output_dtype: SparseType | None, use_triton: bool
+    ) -> None:
+        """Test MX4 dequantization correctness for FP32 and BF16 output dtypes.
+
+        Validates that mx4_to_float produces correct outputs:
+        - dtype matches expected (None -> FP32, SparseType.FP32 -> FP32, SparseType.BF16 -> BF16)
+        - FP32 output matches against cpu reference
+        - BF16 output matches against FP32 -> BF16 conversion
+        """
+        device = torch.device(torch.accelerator.current_accelerator() or "cuda")
+        input_data = torch.randn([128, 1024], device=device, dtype=torch.float32)
+
+        quantized = fp32_to_mx4(
+            input_data,
+            group_size=32,
+            use_triton=use_triton,
+            rounding_mode=RoundingMode.floor,
+        )
+        output = mx4_to_float(
+            quantized,
+            group_size=32,
+            use_triton=use_triton,
+            output_dtype=output_dtype,
+        )
+        output = output.reshape(input_data.shape)  # CUDA returns flattened tensor
+
+        # validate dtype
+        expected_dtype = output_dtype.as_dtype() if output_dtype else torch.float32
+        self.assertEqual(output.dtype, expected_dtype)
+
+        # validate fp32 against cpu reference
+        if expected_dtype == torch.float32:
+            input_cpu = input_data.cpu()
+            quantized_cpu = py_quantize_mx4(
+                input_cpu, group_size=32, rounding_mode=RoundingMode.floor
+            )
+            output_cpu = py_dequantize_mx4(quantized_cpu, group_size=32)
+            assert check_diff_quantize(input_cpu, output_cpu, output.cpu())
+
+        # validate bf16 matches fp32->bf16 conversion
+        elif expected_dtype == torch.bfloat16:
+            output_fp32 = mx4_to_float(
+                quantized,
+                group_size=32,
+                use_triton=use_triton,
+                output_dtype=None,  # Get FP32 output
+            )
+            output_fp32_as_bf16 = output_fp32.reshape(input_data.shape).to(
+                torch.bfloat16
+            )
+            torch.testing.assert_close(
+                output.cpu(), output_fp32_as_bf16.cpu(), rtol=0.0, atol=0.0
+            )
+
     # pyre-fixme[56]:
     @unittest.skipIf(
         not (
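
End to end, mx4_to_fp32 keeps its FP32-only behavior while mx4_to_float exposes the BF16 path, and the new test asserts that dequantizing directly to BF16 matches dequantizing to FP32 and then downcasting. A short sketch of that equivalence using the public API, assuming a CUDA device and illustrative shapes (not values from this patch):

import torch

from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_float, RoundingMode
from fbgemm_gpu.split_embedding_configs import SparseType

x = torch.randn(128, 1024, device="cuda", dtype=torch.float32)
packed = fp32_to_mx4(x, group_size=32, use_triton=True, rounding_mode=RoundingMode.floor)

# New direct MX4 -> BF16 path.
bf16_out = mx4_to_float(packed, group_size=32, use_triton=True, output_dtype=SparseType.BF16)

# Backward-compatible MX4 -> FP32 path (output_dtype=None), then an explicit downcast.
fp32_out = mx4_to_float(packed, group_size=32, use_triton=True, output_dtype=None)

# The two results are expected to match exactly, as asserted by the new test.
torch.testing.assert_close(bf16_out, fp32_out.to(torch.bfloat16), rtol=0.0, atol=0.0)

# SparseType.FP16 is rejected by mx4_to_float with a ValueError: FP16's narrow exponent
# range can overflow/underflow on MX4 values, while BF16 keeps FP32's dynamic range.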