@@ -158,12 +158,22 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd_meta(
c10::SymInt window_size_left,
c10::SymInt window_size_right,
bool bottom_right) {
auto output = at::empty_like(q);
auto out_dtype = q.scalar_type();
if(q.scalar_type() == at::kFloat8_e4m3fn) {
// Output is BF16 when input is FP8
out_dtype = at::kBFloat16;
}
auto output = at::empty_like(q, q.options().dtype(out_dtype));
bool k_is_varlen = max_seq_len_q.has_value();
auto SQ = k_is_varlen ? q.sym_size(0) : q.sym_size(1);
auto H_Q = k_is_varlen ? q.sym_size(1) : q.sym_size(2);
auto B = k_is_varlen ? 1 : q.sym_size(0);
auto logsumexp = q.new_empty_symint({B, H_Q, SQ}, q.options());
if(k_is_varlen) {
// Tweak storage offset of output
auto storage_offset = q.sym_size(-1) * (*max_seq_len_q) * H_Q;
output = output.as_strided_symint(output.sym_sizes(), output.sym_strides(), storage_offset);
}
auto logsumexp = q.new_empty_symint({B, H_Q, SQ}, q.options().dtype(at::kFloat));
return std::make_tuple(output, logsumexp);
}
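
A minimal Python sketch of the shape and dtype contract this meta function encodes (illustrative only, not the FBGEMM API; the function name is hypothetical and the varlen storage-offset adjustment is omitted):

import torch

def fmha_fwd_meta_sketch(q: torch.Tensor, max_seq_len_q=None):
    # FP8 inputs produce BF16 outputs; every other dtype is passed through.
    out_dtype = torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype
    output = torch.empty_like(q, dtype=out_dtype)
    # Varlen layout packs all sequences into dim 0 ([total_q, H, D]);
    # the dense layout is [B, SQ, H, D].
    k_is_varlen = max_seq_len_q is not None
    SQ = q.shape[0] if k_is_varlen else q.shape[1]
    H_Q = q.shape[1] if k_is_varlen else q.shape[2]
    B = 1 if k_is_varlen else q.shape[0]
    # The log-sum-exp is always float32, regardless of the input dtype.
    logsumexp = q.new_empty((B, H_Q, SQ), dtype=torch.float32)
    return output, logsumexp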

@@ -425,15 +425,27 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_gen_fwd_meta(
int64_t window_right,
int64_t split_k_size
) {
// Return tuple matching the operator signature: (output, lse)
at::Tensor output = at::empty_like(q);
// LSE should have shape [B, num_splits, H]
int b = q.size(0);
int h = q.size(2);
// For meta, just create a dummy LSE with single split
at::Tensor lse = at::empty(
{b, 1, h},
at::TensorOptions().dtype(at::kFloat).device(at::kMeta));
auto b = q.sym_size(0);
auto sq = q.sym_size(1);
auto h = q.sym_size(2);
auto d = q.sym_size(3);
assert(sq == 1);
auto sk = k.sym_size(1);
auto h_k = k.sym_size(2);

c10::SymInt split_kv = 1;
if (split_k_size > 0) {
split_kv = (sk + split_k_size - 1) / split_k_size;
}

auto out_dtype = q.scalar_type();
if(q.scalar_type() == at::kFloat8_e4m3fn) {
// Output is BF16 when input is FP8
out_dtype = at::kBFloat16;
}

auto output = at::empty_symint({b, h, split_kv, d}, q.options().dtype(out_dtype));
auto lse = at::empty_symint({b, split_kv, h}, q.options().dtype(at::kFloat));
return std::make_tuple(output, lse);
}
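
For the generation (decode) variant, the contract is a per-split output of shape [B, H, split_kv, D] and an LSE of shape [B, split_kv, H], where split_kv is a ceiling division of the KV length by split_k_size. A hedged Python sketch (names are illustrative, not the FBGEMM API):

import torch

def fmha_gen_fwd_meta_sketch(q: torch.Tensor, k: torch.Tensor, split_k_size: int):
    # Decode path: q is [B, 1, H, D] (a single query token), k is [B, SK, H_K, D].
    b, sq, h, d = q.shape
    assert sq == 1
    sk = k.shape[1]
    # Ceiling division: number of KV splits the kernel reduces over.
    split_kv = (sk + split_k_size - 1) // split_k_size if split_k_size > 0 else 1
    # FP8 inputs produce BF16 outputs, as in the non-generation path.
    out_dtype = torch.bfloat16 if q.dtype == torch.float8_e4m3fn else q.dtype
    output = torch.empty((b, h, split_kv, d), dtype=out_dtype, device=q.device)
    lse = torch.empty((b, split_kv, h), dtype=torch.float32, device=q.device)
    return output, lse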

fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py (247 changes: 133 additions & 114 deletions)
@@ -8,6 +8,7 @@
import os
import random
import unittest
from contextlib import nullcontext
from typing import cast, Optional

import torch
@@ -292,6 +293,7 @@ def _execute_cutlass_blackwell_attn_decode(
window_size: tuple[int, int],
sm_scale: Optional[float],
use_full_seqlen: bool = False,
use_compile: bool = False,
) -> None:
device = torch.accelerator.current_accelerator()
assert device is not None
@@ -385,23 +387,29 @@ def _execute_cutlass_blackwell_attn_decode(
print(f"KV padding (constant): {kv_padding}")
print(f"seqlen_kv (variable): {_seqused_k}")
# Run decode-specific kernel
out, lse = cutlass_blackwell_fmha_decode_forward(
q,
k,
v,
seqlen_kv=_seqused_k,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seq_len_q=None,
max_seq_len_k=None,
softmax_scale=sm_scale,
causal=causal,
window_left=window_size[0],
window_right=window_size[1],
bottom_right=True,
split_k_size=0,
use_heuristic=False,
)
func_to_test = cutlass_blackwell_fmha_decode_forward
forward_test_ctx = nullcontext()
if use_compile:
func_to_test = torch.compile(func_to_test, fullgraph=True)
forward_test_ctx = torch.no_grad()
with forward_test_ctx:
out, lse = func_to_test(
q,
k,
v,
seqlen_kv=_seqused_k,
cu_seqlens_q=None,
cu_seqlens_k=None,
max_seq_len_q=None,
max_seq_len_k=None,
softmax_scale=sm_scale,
causal=causal,
window_left=window_size[0],
window_right=window_size[1],
bottom_right=True,
split_k_size=0,
use_heuristic=False,
)

# Output is [B, 1, H, 1, D] - squeeze num_splits dimension
out = out.squeeze(3) # [B, 1, H, D]
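
The compile-path pattern used here (and echoed in the dense and varlen helpers below) is: wrap the kernel in torch.compile with fullgraph=True so any graph break fails loudly, and run the forward pass under no_grad when no backward is exercised. A small sketch of that pattern under those assumptions (the helper name is hypothetical):

import torch
from contextlib import nullcontext

def run_maybe_compiled(fn, *args, use_compile=False, fwd_only=False, **kwargs):
    # fullgraph=True turns any graph break inside the custom op into a hard failure.
    if use_compile:
        fn = torch.compile(fn, fullgraph=True)
    # Forward-only runs are wrapped in no_grad; the decode helper above also
    # disables grad whenever it compiles, since it never calls backward.
    ctx = torch.no_grad() if fwd_only else nullcontext()
    with ctx:
        return fn(*args, **kwargs)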
@@ -523,63 +531,68 @@ def _execute_cutlass_blackwell_attn_dense(

# Run tested kernel
func_to_test = cutlass_blackwell_fmha_func
forward_test_ctx = nullcontext()
if use_compile:
func_to_test = torch.compile(func_to_test, fullgraph=True)
if is_paged:
assert k_paged is not None and v_paged is not None
out_paged = func_to_test(
q,
k_paged,
v_paged,
causal=causal,
window_size=window_size,
seqlen_kv=seqlen_kv,
page_table=page_table,
seqlen_k=seqlen_k,
deterministic=deterministic,
softmax_scale=sm_scale,
)

out = func_to_test(
q,
k,
v,
causal=causal,
window_size=window_size,
seqlen_kv=seqlen_kv,
page_table=None,
seqlen_k=seqlen_k,
deterministic=deterministic,
softmax_scale=sm_scale,
)

if DEBUG:
print("cutlass_blackwell_fmha_func completed successfully!")

# Follow FlashAttention's numerical evaluation
# Compare outputs
if is_paged:
# Compare paged output with both reference and non paged output
self._allclose(out_paged, out_ref, out_pt)
self._allclose(out_paged, out, out_pt)
else:
self._allclose(out, out_ref, out_pt)
if fwd_only:
forward_test_ctx = torch.no_grad()

with forward_test_ctx:
if is_paged:
assert k_paged is not None and v_paged is not None
out_paged = func_to_test(
q,
k_paged,
v_paged,
causal=causal,
window_size=window_size,
seqlen_kv=seqlen_kv,
page_table=page_table,
seqlen_k=seqlen_k,
deterministic=deterministic,
softmax_scale=sm_scale,
)

if deterministic:
# Rerun the test. The outputs must be bit-wise exact
out_d = func_to_test(
out = func_to_test(
q,
cast(torch.Tensor, k_paged) if is_paged else k,
cast(torch.Tensor, v_paged) if is_paged else v,
k,
v,
causal=causal,
window_size=window_size,
seqlen_kv=seqlen_kv,
page_table=page_table if is_paged else None,
page_table=None,
seqlen_k=seqlen_k,
deterministic=deterministic,
softmax_scale=sm_scale,
)
assert torch.equal(out, out_d)

if DEBUG:
print("cutlass_blackwell_fmha_func completed successfully!")

# Follow FlashAttention's numerical evaluation
# Compare outputs
if is_paged:
# Compare paged output with both reference and non paged output
self._allclose(out_paged, out_ref, out_pt)
self._allclose(out_paged, out, out_pt)
else:
self._allclose(out, out_ref, out_pt)

if deterministic:
# Rerun the test. The outputs must be bit-wise exact
out_d = func_to_test(
q,
cast(torch.Tensor, k_paged) if is_paged else k,
cast(torch.Tensor, v_paged) if is_paged else v,
causal=causal,
window_size=window_size,
seqlen_kv=seqlen_kv,
page_table=page_table if is_paged else None,
seqlen_k=seqlen_k,
deterministic=deterministic,
softmax_scale=sm_scale,
)
assert torch.equal(out, out_d)

if fwd_only:
return
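
The deterministic branch above re-runs the kernel and requires bit-exact equality between the two outputs; as a generic sketch of that check (not the test's helper):

import torch

def assert_bitwise_reproducible(fn, *args, **kwargs):
    out1 = fn(*args, **kwargs)
    out2 = fn(*args, **kwargs)
    # torch.equal compares shape, dtype, and every element exactly (no tolerance).
    assert torch.equal(out1, out2)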
@@ -730,69 +743,74 @@ def _execute_cutlass_blackwell_attn_varlen(
)

func_to_test = cutlass_blackwell_fmha_func
forward_test_ctx = nullcontext()
if use_compile:
func_to_test = torch.compile(func_to_test, fullgraph=True)
if is_paged:
assert k_paged is not None and v_paged is not None
out_unpad_paged = func_to_test(
q_unpad,
k_paged,
v_paged,
causal=causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
page_table=page_table,
window_size=window_size,
deterministic=deterministic,
softmax_scale=sm_scale,
)
out_paged = output_pad_fn(out_unpad_paged)

out_unpad = func_to_test(
q_unpad,
k_unpad,
v_unpad,
causal=causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
page_table=None,
window_size=window_size,
deterministic=deterministic,
softmax_scale=sm_scale,
)
out = output_pad_fn(out_unpad)

# Follow FlashAttention's numerical evaluation
# Compare outputs
if is_paged:
# Compare paged output with both reference and non paged output
self._allclose(out_paged, out_ref, out_pt)
self._allclose(out_paged, out, out_pt)
else:
self._allclose(out, out_ref, out_pt)
if fwd_only:
forward_test_ctx = torch.no_grad()

with forward_test_ctx:
if is_paged:
assert k_paged is not None and v_paged is not None
out_unpad_paged = func_to_test(
q_unpad,
k_paged,
v_paged,
causal=causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
page_table=page_table,
window_size=window_size,
deterministic=deterministic,
softmax_scale=sm_scale,
)
out_paged = output_pad_fn(out_unpad_paged)

if deterministic:
# Rerun the test. The outputs must be bit-wise exact
out_unpad_d = func_to_test(
out_unpad = func_to_test(
q_unpad,
cast(torch.Tensor, k_paged) if is_paged else k_unpad,
cast(torch.Tensor, v_paged) if is_paged else v_unpad,
k_unpad,
v_unpad,
causal=causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
page_table=page_table,
page_table=None,
window_size=window_size,
deterministic=deterministic,
softmax_scale=sm_scale,
)
out_d = output_pad_fn(out_unpad_d)
assert torch.equal(out, out_d)
out = output_pad_fn(out_unpad)

# Follow FlashAttention's numerical evaluation
# Compare outputs
if is_paged:
# Compare paged output with both reference and non paged output
self._allclose(out_paged, out_ref, out_pt)
self._allclose(out_paged, out, out_pt)
else:
self._allclose(out, out_ref, out_pt)

if deterministic:
# Rerun the test. The outputs must be bit-wise exact
out_unpad_d = func_to_test(
q_unpad,
cast(torch.Tensor, k_paged) if is_paged else k_unpad,
cast(torch.Tensor, v_paged) if is_paged else v_unpad,
causal=causal,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seq_len_q=max_seqlen_q,
max_seq_len_k=max_seqlen_k,
page_table=page_table,
window_size=window_size,
deterministic=deterministic,
softmax_scale=sm_scale,
)
out_d = output_pad_fn(out_unpad_d)
assert torch.equal(out, out_d)

if fwd_only:
return
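
As context for the varlen path above, cu_seqlens_q and cu_seqlens_k are cumulative sequence-length offsets into a packed [total_tokens, H, D] layout, and output_pad_fn scatters results back to the padded shape. A minimal sketch of the packing side (names are illustrative, not the test's helpers):

import torch

def pack_varlen(x_padded: torch.Tensor, seqlens: torch.Tensor):
    # x_padded: [B, S_max, H, D]; seqlens: [B] actual lengths per batch entry.
    cu_seqlens = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    # Keep only the valid tokens of each sequence -> [total_tokens, H, D].
    x_unpad = torch.cat(
        [x_padded[b, : int(seqlens[b])] for b in range(seqlens.numel())], dim=0
    )
    return x_unpad, cu_seqlens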
@@ -1531,7 +1549,6 @@ def test_compile(
kv_heads = 2 if is_mqa else q_heads
batch_size = 2
seqlen_k = 128
kv_heads = 2
head_dim = 128
dtype = torch.bfloat16
causal = True
@@ -1554,6 +1571,7 @@
dtype=dtype,
window_size=window_size,
sm_scale=None,
use_compile=True,
)
return

@@ -1572,6 +1590,7 @@
deterministic=deterministic,
sm_scale=sm_scale,
is_paged=False,
use_compile=True,
)

@skip_cuda_lt_sm100