
Conversation

@catcor01
Contributor

No description provided.

@catcor01 force-pushed the multihead_attention branch 2 times, most recently from a98526f to cf45a2e on November 25, 2025 08:59
Collaborator

@Lallapallooza left a comment

Thanks for the patch, a few comments.

- Legalize Torch scaled_dot_product_attention into TOSA by adding the necessary patterns
  in TorchToTosa.cpp plus backend type-conversion hooks.
- Introduce a detailed decomposition path for multi-head attention within DecomposeComplexOps.cpp,
  preparing inputs for TOSA lowering.
- Expand the PT1 e2e suite with a dedicated multi-head attention MLIR/Python test and
  drop the corresponding xfails now that the path works.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I96c17aefd25b979f1cf6e897d91d5a29f0a2fa85
@catcor01 force-pushed the multihead_attention branch from fd02d37 to 9ef3ed1 on December 16, 2025 11:35
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
Member

Is this pattern needed anymore with the change in fx_decomp_util?

Contributor Author

Please correct me if I misunderstand, but I believe we still need the MLIR-side pattern. The new entry in python/torch_mlir/extras/fx_decomp_util.py only affects the FX/ExportedProgram import path. Other frontends—TorchScript, AOTAutograd, or anyone who feeds raw Torch dialect into torch-mlir-opt—never touch that Python list, so they can still produce torch.aten.scaled_dot_product_attention. For those cases the rewrite in lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp is what lowers sdpa into the matmul/softmax pipeline so that downstream -convert-torch-to-tosa or -convert-torch-to-linalg keeps working.
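For reference, the matmul/softmax pipeline referred to here computes the standard scaled dot product attention (mask and dropout handling aside), with d_k the key/query head dimension:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V$$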

Member

Based on "This path doesn't give access to the current generation work that is being driven via the fx_importer", IIUC the fx_importer path is the only maintained path; the rest have been deprecated but the code still exists. Maybe @sjarus / @zjgarvey can confirm or correct that understanding, and we can discuss whether it's still valuable to have this decomposition pattern or whether we can rely on PyTorch's decomposition.

I don't have any issue with adding this, I just want to make sure that it will actually be exercised.

Collaborator

Yes, the fx_importer is the only path we should expect to support.

I have found attention to be a bit frustrating, however. For example, running decompositions on an exported program with an sdpa op sometimes converts sdpa into a slightly different attention op, even when attention itself isn't being decomposed. Merely running decompositions at all retraces the graph with a different tool, and it may select different ops depending on other factors, like the torch device used by the inputs.

In any case, I don't mind adding a decomposition pattern. We have a bit more control with a pattern like this as opposed to fx decompositions.

Contributor Author

@sahas3, based on the above comment, are you happy to keep the decomposition as is, or would you prefer to remove it?

Member

Thanks for the reminder, I'm fine with adding the pattern.

The puzzle now is that if we add sdpa to the decomp list in fx_decomp_util, then although we have a LIT test locking down the C++ decomposition pattern, that pattern won't be exercised by the e2e test. However, the ops we decompose to with the C++ pattern are all locked down via e2e tests (I hope :D). So not being able to e2e test this C++ decomp pattern is probably fine? Thoughts @zjgarvey / @sjarus?

On that note, we probably need a way to control which ops get decomposed with PyTorch's decompositions at the e2e test level, but that's out of scope for this PR.

Member

@sahas3 left a comment

Thanks for the change. I'm fine with adding the decomposition pattern.

SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
                                     ShapedType::kDynamic);
if (rankedSelf.hasStaticShape()) {
  auto staticShape =
Member

to_vector is redundant since the return type of makeShapeTorchCompatible is already SmallVector
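A minimal sketch of the suggested simplification, assuming (per the comment above) that makeShapeTorchCompatible already returns a SmallVector:

// The llvm::to_vector copy can simply be dropped; the helper's return value
// is already a SmallVector.
auto staticShape = makeShapeTorchCompatible(rankedSelf.getShape());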

      llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
  auto dim0Index = static_cast<size_t>(dim0);
  auto dim1Index = static_cast<size_t>(dim1);
  if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
Member

Isn't this condition always guaranteed by the check if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank))?

  for (size_t i = 0; i < staticShape.size(); ++i)
    transposedShape[i] = staticShape[i];
}
auto rankedResult = RankedTensorType::get(
Member

IIUC, you are computing the transposed shape for statically shaped inputs and using that to construct the tosa::TransposeOp. I think we can use tosa::CreateOpAndInfer with UnrankedTensorType::get(elemTy) and let the transpose op creation infer the resultType instead of computing it here.
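A hedged sketch of that suggestion; it assumes the tosa::CreateOpAndInfer helper mentioned above is available, that self, elemTy, and a permutation value perms are already in scope, and that the form of the perms argument matches the TOSA dialect version in use:

// Let TOSA shape inference produce the transpose result type instead of
// computing transposedShape by hand.
auto transposed = tosa::CreateOpAndInfer<tosa::TransposeOp>(
    rewriter, op->getLoc(), UnrankedTensorType::get(elemTy), self, perms);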

Comment on lines +9425 to +9439
target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
  auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
  auto resultType = dyn_cast<RankedTensorType>(op.getType());
  if (!sourceType || !resultType)
    return true;
  if (sourceType.getElementType() != resultType.getElementType())
    return true;
  if (!sourceType.hasStaticShape())
    return true;
  if (!resultType.hasStaticShape())
    return true;
  if (sourceType == resultType)
    return true;
  return false;
});
Member

I am guessing this is needed because of the targetMaterialization change?

legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
Member

Suggested change:
- patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+ addPatternIfTargetOpIsIllegal<DecomposeAtenScaledDotProductAttentionOp>(patterns);

Comment on lines +2380 to +2383
if (static_cast<int64_t>(keySizes.size()) != queryRank ||
    static_cast<int64_t>(valueSizes.size()) != queryRank)
  return rewriter.notifyMatchFailure(
      op, "expected query, key, and value to share rank");
Member

This failure check should happen before any IR modification starts; otherwise the IR will be left in a bad state, since we've already introduced new ops but the original op is not replaced, eventually leading to the pass failing.
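A hedged sketch of the ordering being requested, reusing the checks quoted in this review; variable setup and the actual op emission are elided:

// 1. Every match-failure check comes first, before the first
//    rewriter.create<> call, so a failed match leaves the IR untouched.
if (static_cast<int64_t>(keySizes.size()) != queryRank ||
    static_cast<int64_t>(valueSizes.size()) != queryRank)
  return rewriter.notifyMatchFailure(
      op, "expected query, key, and value to share rank");
if (keyTransposedSizes.size() < 2)
  return rewriter.notifyMatchFailure(
      op, "expected key tensor rank >= 2 for transpose");

// 2. Only after all checks pass, start creating the decomposed ops and
//    finish with rewriter.replaceOp(op, ...).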

Comment on lines +2416 to +2418
if (keyTransposedSizes.size() < 2)
  return rewriter.notifyMatchFailure(
      op, "expected key tensor rank >= 2 for transpose");
Member

This too has to be checked earlier.

Comment on lines +2492 to +2494
if (!softmax)
  return rewriter.notifyMatchFailure(op,
                                     "failed to compute softmax scores");
Member

This one is probably fine. Looking at getSoftmaxResult, I don't expect it to fail at this point. I can't think of a good way to guarantee that, though, short of replicating getSoftmaxResult's checks here early enough, and that doesn't seem like a good idea.

@@ -0,0 +1,29 @@
// RUN: torch-mlir-opt %s -torch-decompose-complex-ops -convert-torch-to-tosa -split-input-file | FileCheck %s

// Checks that scaled dot product attention (single-head configuration) lowers
Member

I think a better place for this test is a file driven by

// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s

since the decomposition path is common across all backend paths.
