diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu
index c2a16eacea..f3c5538eb4 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu
@@ -158,12 +158,22 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd_meta(
     c10::SymInt window_size_left,
     c10::SymInt window_size_right,
     bool bottom_right) {
-  auto output = at::empty_like(q);
+  auto out_dtype = q.scalar_type();
+  if(q.scalar_type() == at::kFloat8_e4m3fn) {
+    // Output is BF16 when input is FP8
+    out_dtype = at::kBFloat16;
+  }
+  auto output = at::empty_like(q, q.options().dtype(out_dtype));
   bool k_is_varlen = max_seq_len_q.has_value();
   auto SQ = k_is_varlen ? q.sym_size(0) : q.sym_size(1);
   auto H_Q = k_is_varlen ? q.sym_size(1) : q.sym_size(2);
   auto B = k_is_varlen ? 1 : q.sym_size(0);
-  auto logsumexp = q.new_empty_symint({B, H_Q, SQ}, q.options());
+  if(k_is_varlen) {
+    // Tweak storage offset of output
+    auto storage_offset = q.sym_size(-1) * (*max_seq_len_q) * H_Q;
+    output = output.as_strided_symint(output.sym_sizes(), output.sym_strides(), storage_offset);
+  }
+  auto logsumexp = q.new_empty_symint({B, H_Q, SQ}, q.options().dtype(at::kFloat));
   return std::make_tuple(output, logsumexp);
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
index 638b6a495d..875fa5ca2b 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_gen_impl.cu
@@ -425,15 +425,27 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_gen_fwd_meta(
     int64_t window_right,
     int64_t split_k_size
 ) {
-  // Return tuple matching the operator signature: (output, lse)
-  at::Tensor output = at::empty_like(q);
-  // LSE should have shape [B, num_splits, H]
-  int b = q.size(0);
-  int h = q.size(2);
-  // For meta, just create a dummy LSE with single split
-  at::Tensor lse = at::empty(
-      {b, 1, h},
-      at::TensorOptions().dtype(at::kFloat).device(at::kMeta));
+  auto b = q.sym_size(0);
+  auto sq = q.sym_size(1);
+  auto h = q.sym_size(2);
+  auto d = q.sym_size(3);
+  assert(sq == 1);
+  auto sk = k.sym_size(1);
+  auto h_k = k.sym_size(2);
+
+  c10::SymInt split_kv = 1;
+  if (split_k_size > 0) {
+    split_kv = (sk + split_k_size - 1) / split_k_size;
+  }
+
+  auto out_dtype = q.scalar_type();
+  if(q.scalar_type() == at::kFloat8_e4m3fn) {
+    // Output is BF16 when input is FP8
+    out_dtype = at::kBFloat16;
+  }
+
+  auto output = at::empty_symint({b, h, split_kv, d}, q.options().dtype(out_dtype));
+  auto lse = at::empty_symint({b, split_kv, h}, q.options().dtype(at::kFloat));
   return std::make_tuple(output, lse);
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
index 6ff8760077..cda5b0e490 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
@@ -8,6 +8,7 @@
 import os
 import random
 import unittest
+from contextlib import nullcontext
 from typing import cast, Optional
 
 import torch
@@ -292,6 +293,7 @@ def _execute_cutlass_blackwell_attn_decode(
         window_size: tuple[int, int],
         sm_scale: Optional[float],
         use_full_seqlen: bool = False,
+        use_compile: bool = False,
     ) -> None:
         device = torch.accelerator.current_accelerator()
         assert device is not None
@@ -385,23 +387,29 @@ def _execute_cutlass_blackwell_attn_decode(
             print(f"KV padding (constant): {kv_padding}")
             print(f"seqlen_kv (variable): {_seqused_k}")
         # Run decode-specific kernel
-        out, lse = cutlass_blackwell_fmha_decode_forward(
-            q,
-            k,
-            v,
-            seqlen_kv=_seqused_k,
-            cu_seqlens_q=None,
-            cu_seqlens_k=None,
-            max_seq_len_q=None,
-            max_seq_len_k=None,
-            softmax_scale=sm_scale,
-            causal=causal,
-            window_left=window_size[0],
-            window_right=window_size[1],
-            bottom_right=True,
-            split_k_size=0,
-            use_heuristic=False,
-        )
+        func_to_test = cutlass_blackwell_fmha_decode_forward
+        forward_test_ctx = nullcontext()
+        if use_compile:
+            func_to_test = torch.compile(func_to_test, fullgraph=True)
+            forward_test_ctx = torch.no_grad()
+        with forward_test_ctx:
+            out, lse = func_to_test(
+                q,
+                k,
+                v,
+                seqlen_kv=_seqused_k,
+                cu_seqlens_q=None,
+                cu_seqlens_k=None,
+                max_seq_len_q=None,
+                max_seq_len_k=None,
+                softmax_scale=sm_scale,
+                causal=causal,
+                window_left=window_size[0],
+                window_right=window_size[1],
+                bottom_right=True,
+                split_k_size=0,
+                use_heuristic=False,
+            )
 
         # Output is [B, 1, H, 1, D] - squeeze num_splits dimension
         out = out.squeeze(3)  # [B, 1, H, D]
@@ -523,63 +531,68 @@ def _execute_cutlass_blackwell_attn_dense(
 
         # Run tested kernel
         func_to_test = cutlass_blackwell_fmha_func
+        forward_test_ctx = nullcontext()
         if use_compile:
             func_to_test = torch.compile(func_to_test, fullgraph=True)
-        if is_paged:
-            assert k_paged is not None and v_paged is not None
-            out_paged = func_to_test(
-                q,
-                k_paged,
-                v_paged,
-                causal=causal,
-                window_size=window_size,
-                seqlen_kv=seqlen_kv,
-                page_table=page_table,
-                seqlen_k=seqlen_k,
-                deterministic=deterministic,
-                softmax_scale=sm_scale,
-            )
-
-        out = func_to_test(
-            q,
-            k,
-            v,
-            causal=causal,
-            window_size=window_size,
-            seqlen_kv=seqlen_kv,
-            page_table=None,
-            seqlen_k=seqlen_k,
-            deterministic=deterministic,
-            softmax_scale=sm_scale,
-        )
-
-        if DEBUG:
-            print("cutlass_blackwell_fmha_func completed successfully!")
-
-        # Follow FlashAttention's numerical evaluation
-        # Compare outputs
-        if is_paged:
-            # Compare paged output with both reference and non paged output
-            self._allclose(out_paged, out_ref, out_pt)
-            self._allclose(out_paged, out, out_pt)
-        else:
-            self._allclose(out, out_ref, out_pt)
+        if fwd_only:
+            forward_test_ctx = torch.no_grad()
+
+        with forward_test_ctx:
+            if is_paged:
+                assert k_paged is not None and v_paged is not None
+                out_paged = func_to_test(
+                    q,
+                    k_paged,
+                    v_paged,
+                    causal=causal,
+                    window_size=window_size,
+                    seqlen_kv=seqlen_kv,
+                    page_table=page_table,
+                    seqlen_k=seqlen_k,
+                    deterministic=deterministic,
+                    softmax_scale=sm_scale,
+                )
 
-        if deterministic:
-            # Rerun the test. The outputs must be bit-wise exact
-            out_d = func_to_test(
+            out = func_to_test(
                 q,
-                cast(torch.Tensor, k_paged) if is_paged else k,
-                cast(torch.Tensor, v_paged) if is_paged else v,
+                k,
+                v,
                 causal=causal,
                 window_size=window_size,
                 seqlen_kv=seqlen_kv,
-                page_table=page_table if is_paged else None,
+                page_table=None,
                 seqlen_k=seqlen_k,
                 deterministic=deterministic,
                 softmax_scale=sm_scale,
             )
-            assert torch.equal(out, out_d)
+
+            if DEBUG:
+                print("cutlass_blackwell_fmha_func completed successfully!")
+
+            # Follow FlashAttention's numerical evaluation
+            # Compare outputs
+            if is_paged:
+                # Compare paged output with both reference and non paged output
+                self._allclose(out_paged, out_ref, out_pt)
+                self._allclose(out_paged, out, out_pt)
+            else:
+                self._allclose(out, out_ref, out_pt)
+
+            if deterministic:
+                # Rerun the test. The outputs must be bit-wise exact
+                out_d = func_to_test(
+                    q,
+                    cast(torch.Tensor, k_paged) if is_paged else k,
+                    cast(torch.Tensor, v_paged) if is_paged else v,
+                    causal=causal,
+                    window_size=window_size,
+                    seqlen_kv=seqlen_kv,
+                    page_table=page_table if is_paged else None,
+                    seqlen_k=seqlen_k,
+                    deterministic=deterministic,
+                    softmax_scale=sm_scale,
+                )
+                assert torch.equal(out, out_d)
 
         if fwd_only:
             return
@@ -730,69 +743,74 @@ def _execute_cutlass_blackwell_attn_varlen(
         )
 
         func_to_test = cutlass_blackwell_fmha_func
+        forward_test_ctx = nullcontext()
         if use_compile:
             func_to_test = torch.compile(func_to_test, fullgraph=True)
-        if is_paged:
-            assert k_paged is not None and v_paged is not None
-            out_unpad_paged = func_to_test(
-                q_unpad,
-                k_paged,
-                v_paged,
-                causal=causal,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seq_len_q=max_seqlen_q,
-                max_seq_len_k=max_seqlen_k,
-                page_table=page_table,
-                window_size=window_size,
-                deterministic=deterministic,
-                softmax_scale=sm_scale,
-            )
-            out_paged = output_pad_fn(out_unpad_paged)
-
-        out_unpad = func_to_test(
-            q_unpad,
-            k_unpad,
-            v_unpad,
-            causal=causal,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seq_len_q=max_seqlen_q,
-            max_seq_len_k=max_seqlen_k,
-            page_table=None,
-            window_size=window_size,
-            deterministic=deterministic,
-            softmax_scale=sm_scale,
-        )
-        out = output_pad_fn(out_unpad)
-
-        # Follow FlashAttention's numerical evaluation
-        # Compare outputs
-        if is_paged:
-            # Compare paged output with both reference and non paged output
-            self._allclose(out_paged, out_ref, out_pt)
-            self._allclose(out_paged, out, out_pt)
-        else:
-            self._allclose(out, out_ref, out_pt)
+        if fwd_only:
+            forward_test_ctx = torch.no_grad()
+
+        with forward_test_ctx:
+            if is_paged:
+                assert k_paged is not None and v_paged is not None
+                out_unpad_paged = func_to_test(
+                    q_unpad,
+                    k_paged,
+                    v_paged,
+                    causal=causal,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seq_len_q=max_seqlen_q,
+                    max_seq_len_k=max_seqlen_k,
+                    page_table=page_table,
+                    window_size=window_size,
+                    deterministic=deterministic,
+                    softmax_scale=sm_scale,
+                )
+                out_paged = output_pad_fn(out_unpad_paged)
 
-        if deterministic:
-            # Rerun the test. The outputs must be bit-wise exact
-            out_unpad_d = func_to_test(
+            out_unpad = func_to_test(
                 q_unpad,
-                cast(torch.Tensor, k_paged) if is_paged else k_unpad,
-                cast(torch.Tensor, v_paged) if is_paged else v_unpad,
+                k_unpad,
+                v_unpad,
                 causal=causal,
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_k=cu_seqlens_k,
                 max_seq_len_q=max_seqlen_q,
                 max_seq_len_k=max_seqlen_k,
-                page_table=page_table,
+                page_table=None,
                 window_size=window_size,
                 deterministic=deterministic,
                 softmax_scale=sm_scale,
             )
-            out_d = output_pad_fn(out_unpad_d)
-            assert torch.equal(out, out_d)
+            out = output_pad_fn(out_unpad)
+
+            # Follow FlashAttention's numerical evaluation
+            # Compare outputs
+            if is_paged:
+                # Compare paged output with both reference and non paged output
+                self._allclose(out_paged, out_ref, out_pt)
+                self._allclose(out_paged, out, out_pt)
+            else:
+                self._allclose(out, out_ref, out_pt)
+
+            if deterministic:
+                # Rerun the test. The outputs must be bit-wise exact
+                out_unpad_d = func_to_test(
+                    q_unpad,
+                    cast(torch.Tensor, k_paged) if is_paged else k_unpad,
+                    cast(torch.Tensor, v_paged) if is_paged else v_unpad,
+                    causal=causal,
+                    cu_seqlens_q=cu_seqlens_q,
+                    cu_seqlens_k=cu_seqlens_k,
+                    max_seq_len_q=max_seqlen_q,
+                    max_seq_len_k=max_seqlen_k,
+                    page_table=page_table,
+                    window_size=window_size,
+                    deterministic=deterministic,
+                    softmax_scale=sm_scale,
+                )
+                out_d = output_pad_fn(out_unpad_d)
+                assert torch.equal(out, out_d)
 
         if fwd_only:
             return
@@ -1531,7 +1549,6 @@ def test_compile(
         kv_heads = 2 if is_mqa else q_heads
         batch_size = 2
         seqlen_k = 128
-        kv_heads = 2
         head_dim = 128
         dtype = torch.bfloat16
         causal = True
@@ -1554,6 +1571,7 @@ def test_compile(
                 dtype=dtype,
                 window_size=window_size,
                 sm_scale=None,
+                use_compile=True,
             )
             return
 
@@ -1572,6 +1590,7 @@ def test_compile(
             deterministic=deterministic,
             sm_scale=sm_scale,
             is_paged=False,
+            use_compile=True,
         )
 
     @skip_cuda_lt_sm100