Commit 05858f6

Custom op to update cache for torch.cond
ghstack-source-id: a1ca30a
ghstack-comment-id: 3683802199
Pull-Request: #16366
1 parent 9374916 commit 05858f6

File tree

3 files changed: +369 -2 lines changed


extension/llm/custom_ops/custom_ops.py

Lines changed: 88 additions & 0 deletions
@@ -12,10 +12,16 @@

 import logging

+from typing import Tuple
+
 import torch

+from torch._inductor.lowering import lowerings as L, register_lowering
+
 from torch.library import impl

+aten = torch.ops.aten
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,85 @@ def custom_quantized_sdpa_meta(
     )

     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # no copies, just pass-through
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata for compile-time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op – just pass them through.
+    return x, y
+
+
+# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# This is cheating: we deliberately do NOT mark `cache` as mutated so that this
+# custom op can be used in HOPs such as `torch.cond`, where `torch.compile` requires
+# no aliasing or mutation in the branches. This is fine because we only care about inference.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    # Eager implementation
+    _validate_cross_attn_cache_params(value, cache)
+
+    # Slice the cache to match value's sequence length and copy
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+    cache[:, :, : value.size(2), :].copy_(value)
+    # Return a clone of the cache to avoid aliasing with the input cache, so that we can still run the exported program.
+    return cache.clone()
+
+
+# Register the fake (meta) kernel
+@_update_cross_attn_cache.register_fake
+def _update_cross_attn_cache_fake(
+    value: torch.Tensor, cache: torch.Tensor
+) -> torch.Tensor:
+    _validate_cross_attn_cache_params(value, cache)
+    return torch.empty_like(cache)
+
+
+# Register Inductor lowering
+@register_lowering(torch.ops.executorch.update_cross_attn_cache)
+def _update_cross_attn_cache_lowering(value, cache):
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+
+    # We need to slice the cache along dim 2 (sequence length)
+    # slice(self, dim, start, end, step=1)
+    seq_len = value.get_size()[2]
+    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
+
+    # Copy value into the slice
+    L[aten.copy_.default](cache_slice, value)
+
+    return cache
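
How the new ops are meant to be used (an illustrative note, not part of the diff): torch.compile requires that the branches passed to torch.cond neither mutate nor alias their inputs, so an in-place cache.copy_() cannot appear inside a branch. Routing the update through executorch::update_cross_attn_cache, which declares mutates_args=() and returns a clone, keeps the branch functional from the compiler's point of view while still filling the cache in eager mode. Below is a minimal sketch, assuming the ops are registered by importing this module; the maybe_update_cache wrapper and its skip branch are made up for illustration, and the new test file below exercises the same paths in more detail.

import torch

# Importing the module registers executorch::alias and executorch::update_cross_attn_cache.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


@torch.compile
def maybe_update_cache(pred, value, cache):
    def update(v, c):
        # The mutation happens inside the custom op, so the branch looks functional.
        return torch.ops.executorch.update_cross_attn_cache(v, c)

    def skip(v, c):
        # Branches must not alias their inputs, so return a copy rather than `c` itself.
        return c.clone()

    return torch.cond(pred, update, skip, (value, cache))


cache = torch.zeros(2, 1, 4, 4)   # [B, H, S_max, D]
value = torch.randn(2, 1, 2, 4)   # [B, H, S, D] with S <= S_max
out = maybe_update_cache(torch.tensor(True), value, cache)
# cache[:, :, :2, :] now equals value; `out` is a clone of the updated cache.
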
Lines changed: 279 additions & 0 deletions
@@ -0,0 +1,279 @@
import unittest

import torch

# Import the custom ops to ensure they are registered
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# Check CUDA availability once at module level
CUDA_AVAILABLE = torch.cuda.is_available()


class TestUpdateCrossAttnCache(unittest.TestCase):
    def test_update_cross_attn_cache(self):

        # Create tensors
        # Cache: [B=2, H=1, S_max=4, D=4]
        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        # Value: [B=2, H=1, S=2, D=4] (S < S_max)
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)

        # Compile a function that uses the op
        @torch.compile
        def fn(v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)

        # Run it
        out = fn(value, cache)

        # Check correctness
        # The first 2 elements in dim 2 (sequence dim) should match value
        torch.testing.assert_close(
            cache[:, :, :2, :], value, msg="Cache slice not updated correctly"
        )

        # Make sure out and cache are close. In eager they are the same objects.
        torch.testing.assert_close(
            out, cache, msg="Output and cache are different objects"
        )

        # The rest should be zeros
        torch.testing.assert_close(
            cache[:, :, 2:, :],
            torch.zeros_like(cache[:, :, 2:, :]),
            msg="Rest of cache was modified",
        )

    def test_update_cross_attn_cache_in_cond(self):
        # Create tensors

        # Value: [B=2, H=1, S=2, D=4]
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
        # Alternative value for false branch
        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)

        # Define a function that uses the op inside torch.cond
        def fn_with_cond(pred, v1, v2, c):
            def true_fn(v1, v2, cache):
                return torch.ops.executorch.update_cross_attn_cache(v1, cache)

            def false_fn(v1, v2, cache):
                return torch.ops.executorch.update_cross_attn_cache(v2, cache)

            return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

        # Test with true condition
        pred_true = torch.tensor(True)
        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)

        # Compile the function
        @torch.compile
        def compiled_fn(pred, v1, v2, c):
            return fn_with_cond(pred, v1, v2, c)

        # Run with true condition
        compiled_fn(pred_true, value, value_alt, cache_true)

        # Check that the true branch was executed (value was used)
        torch.testing.assert_close(
            cache_true[:, :, :2, :],
            value,
            msg="Cache not updated correctly in true branch",
        )

        # Test with false condition
        pred_false = torch.tensor(False)
        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)

        compiled_fn(pred_false, value, value_alt, cache_false)

        # Check that the false branch was executed (value_alt was used)
        torch.testing.assert_close(
            cache_false[:, :, :2, :],
            value_alt,
            msg="Cache not updated correctly in false branch",
        )

    def test_update_cross_attn_cache_export(self):

        # Create tensors
        # Cache: [B=2, H=1, S_max=4, D=4]
        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        # Value: [B=2, H=1, S=2, D=4]
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
        # Alternative value for false branch
        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)

        # Define a module that uses torch.cond with the op
        class UpdateCacheCondModule(torch.nn.Module):
            def forward(self, pred, v1, v2, c):
                def true_fn(v1, v2, cache):
                    return torch.ops.executorch.update_cross_attn_cache(v1, cache)

                def false_fn(v1, v2, cache):
                    return torch.ops.executorch.update_cross_attn_cache(v2, cache)

                return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

        module = UpdateCacheCondModule()

        # Export the module with true condition
        pred_true = torch.tensor(True)
        exported_program = torch.export.export(
            module,
            (pred_true, value, value_alt, cache),
        )

        # Run the exported program with true condition
        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        exported_program.module()(pred_true, value, value_alt, cache_true)

        # Check that the true branch was executed (value was used)
        torch.testing.assert_close(
            cache_true[:, :, :2, :],
            value,
            msg="Cache not updated correctly in true branch after export",
        )

        # Run the exported program with false condition
        pred_false = torch.tensor(False)
        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        exported_program.module()(pred_false, value, value_alt, cache_false)

        # Check that the false branch was executed (value_alt was used)
        torch.testing.assert_close(
            cache_false[:, :, :2, :],
            value_alt,
            msg="Cache not updated correctly in false branch after export",
        )

    def test_update_cross_attn_cache_different_shapes(self):
        print("Testing executorch::update_cross_attn_cache with different shapes...")

        # Test with different batch sizes and sequence lengths
        test_cases = [
            # (B, H, S_max, S, D)
            (1, 2, 10, 5, 8),
            (4, 4, 8, 3, 16),
            (2, 1, 16, 10, 32),
        ]

        for B, H, S_max, S, D in test_cases:
            # Cache: [B, H, S_max, D], Value: [B, H, S, D]
            cache = torch.zeros(B, H, S_max, D, dtype=torch.float32)
            value = torch.randn(B, H, S, D, dtype=torch.float32)

            @torch.compile
            def fn(v, c):
                return torch.ops.executorch.update_cross_attn_cache(v, c)

            fn(value, cache)

            # Check that the first S positions in dim 2 are updated
            torch.testing.assert_close(
                cache[:, :, :S, :],
                value,
                msg=f"Failed for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}",
            )

            # Check that the rest remain zeros
            if S < S_max:
                torch.testing.assert_close(
                    cache[:, :, S:, :],
                    torch.zeros_like(cache[:, :, S:, :]),
                    msg=f"Remaining cache modified for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}",
                )

    def test_update_cross_attn_cache_full_sequence(self):

        # Cache: [B=2, H=1, S_max=4, D=4]
        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        # Value: [B=2, H=1, S=4, D=4] (S == S_max)
        value = torch.randn(2, 1, 4, 4, dtype=torch.float32)

        @torch.compile
        def fn(v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)

        fn(value, cache)

        # The entire cache should match value
        torch.testing.assert_close(
            cache, value, msg="Cache not fully updated when S == S_max"
        )

    @unittest.skipUnless(CUDA_AVAILABLE, "CUDA not available")
    def test_alias_and_update_cross_attn_cache_with_cond_triton(self):
        """Test combining alias and update_cross_attn_cache ops with torch.cond,
        lowered to Triton on CUDA. True branch uses alias, false branch uses
        update_cross_attn_cache."""

        # Create CUDA tensors
        # Value: [B=2, H=1, S=2, D=4]
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32, device="cuda")
        # Extra tensor for alias op
        extra = torch.randn(2, 1, 4, 4, dtype=torch.float32, device="cuda")

        # Define a function that uses different ops in each branch
        def fn_with_cond(pred, v, extra_tensor, c):
            def true_fn(v, extra_tensor, cache):
                # True branch: use alias op only
                aliased_cache, aliased_extra = torch.ops.executorch.alias(
                    cache, extra_tensor
                )
                # Return sum of aliased tensors (no cache mutation)
                return aliased_cache + aliased_extra

            def false_fn(v, extra_tensor, cache):
                # False branch: use update_cross_attn_cache op only
                updated = torch.ops.executorch.update_cross_attn_cache(v, cache)
                return updated

            return torch.cond(pred, true_fn, false_fn, (v, extra_tensor, c))

        # Compile the function with Triton backend
        @torch.compile(backend="inductor")
        def compiled_fn(pred, v, extra_tensor, c):
            return fn_with_cond(pred, v, extra_tensor, c)

        # Test with true condition (alias branch)
        pred_true = torch.tensor(True, device="cuda")
        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda")

        result_true = compiled_fn(pred_true, value, extra, cache_true)

        # Check that the true branch was executed (alias: cache + extra)
        expected_true = cache_true + extra
        torch.testing.assert_close(
            result_true,
            expected_true,
            msg="Result incorrect in true branch (alias) with CUDA/Triton",
        )

        # Cache should remain unchanged in true branch (alias doesn't mutate)
        torch.testing.assert_close(
            cache_true,
            torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda"),
            msg="Cache should not be mutated in true branch (alias)",
        )

        # Test with false condition (update_cross_attn_cache branch)
        pred_false = torch.tensor(False, device="cuda")
        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda")

        compiled_fn(pred_false, value, extra, cache_false)

        # Check that the false branch was executed (update_cross_attn_cache)
        # The cache should be updated with value in the first S positions
        torch.testing.assert_close(
            cache_false[:, :, :2, :],
            value,
            msg="Cache not updated correctly in false branch with CUDA/Triton",
        )

        # The rest of the cache should remain zeros
        torch.testing.assert_close(
            cache_false[:, :, 2:, :],
            torch.zeros(2, 1, 2, 4, dtype=torch.float32, device="cuda"),
            msg="Rest of cache was modified in false branch",
        )
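
For context on the intended use case (again illustrative; the module name, shapes, and wiring below are assumptions, not part of this commit): in an encoder-decoder model the cross-attention K/V cache only needs to be written once, on the first decode step, and is reused on every later step. Guarding that one-time fill with torch.cond is exactly what the hidden-mutation trick above enables, and the same pattern goes through torch.export in the same way as the export test above.

import torch

from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


class CrossAttnCacheFill(torch.nn.Module):
    # Hypothetical inference-only wrapper: fill the cache on the first step, reuse it later.
    def forward(self, is_first_step, encoder_kv, cache):
        def fill(kv, c):
            # First decode step: copy encoder K/V into the cache (mutation hidden in the op).
            return torch.ops.executorch.update_cross_attn_cache(kv, c)

        def reuse(kv, c):
            # Later steps: the cache is already populated; return a copy to avoid aliasing.
            return c.clone()

        return torch.cond(is_first_step, fill, reuse, (encoder_kv, cache))


cache = torch.zeros(1, 8, 128, 64)      # [B, H, S_max, D]
encoder_kv = torch.randn(1, 8, 32, 64)  # [B, H, S, D]
ep = torch.export.export(CrossAttnCacheFill(), (torch.tensor(True), encoder_kv, cache))
updated = ep.module()(torch.tensor(True), encoder_kv, cache)
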

torch_pin.py

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-TORCH_VERSION = "2.10.0"
-NIGHTLY_VERSION = "dev20251120"
+TORCH_VERSION = "2.11.0"
+NIGHTLY_VERSION = "dev20251222"
