[TOSA] MultiheadAttention legalization #4382
Conversation
Force-pushed from a98526f to cf45a2e.
Lallapallooza left a comment:
Thanks for the patch, a few comments.
projects/pt1/test/python/scaled_dot_product_attention_lowering.py (outdated; resolved)
Force-pushed from cf45a2e to fd02d37.
- Legalize Torch scaled_dot_product_attention into TOSA by adding the necessary patterns in TorchToTosa.cpp plus backend type-conversion hooks.
- Introduce a detailed decomposition path for multi-head attention within DecomposeComplexOps.cpp, preparing inputs for TOSA lowering.
- Expand the PT1 e2e suite with a dedicated multi-head attention MLIR/Python test and drop the corresponding xfails now that the path works.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I96c17aefd25b979f1cf6e897d91d5a29f0a2fa85
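For reference, scaled_dot_product_attention computes the standard scaled dot-product attention, which the decomposition introduced here expands into explicit transpose/matmul/softmax ops:

$$\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V$$

where $d_k$ is the head dimension and $M$ is the optional additive attention mask (omitted when no mask is supplied).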
Force-pushed from fd02d37 to 9ef3ed1.
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
Is this pattern needed anymore with the change in fx_decomp_util?
Please correct me if I misunderstand, but I believe we still need the MLIR-side pattern. The new entry in python/torch_mlir/extras/fx_decomp_util.py only affects the FX/ExportedProgram import path. Other frontends—TorchScript, AOTAutograd, or anyone who feeds raw Torch dialect into torch-mlir-opt—never touch that Python list, so they can still produce torch.aten.scaled_dot_product_attention. For those cases the rewrite in lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp is what lowers sdpa into the matmul/softmax pipeline so that downstream -convert-torch-to-tosa or -convert-torch-to-linalg keeps working.
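As a rough sketch of what that MLIR-side pattern looks like (hypothetical skeleton only; the class name matches the registration quoted above, but the body is illustrative and elides the actual op construction):

```cpp
#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

using namespace mlir;
using namespace mlir::torch::Torch;

namespace {
// Illustrative skeleton only; the real pattern in this PR may differ.
class DecomposeAtenScaledDotProductAttentionOp
    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
                                PatternRewriter &rewriter) const override {
    // 1. Validate query/key/value ranks and dtypes; every bail-out happens
    //    here, before any new IR is created.
    // 2. Emit transpose(key), matmul(query, key^T), scaling, the optional
    //    mask add, softmax, and a final matmul with value as Torch ops.
    // 3. rewriter.replaceOp(op, attentionOutput);
    // Construction elided in this sketch.
    return failure();
  }
};
} // namespace
```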
Based on torch-mlir/docs/development.md (line 244 in 0844d4d):

> This path doesn't give access to the current generation work that is being driven via the fx_importer

the fx_importer path is the only maintained path. The rest have been deprecated, but the code still exists. Maybe @sjarus / @zjgarvey can confirm or correct that understanding, and we can discuss whether it's still valuable to have this decomposition pattern or whether we can rely on PyTorch's decomposition.
I don't have any objection to adding this; I just want to make sure it will actually be exercised.
Yes, the fx_importer is the only path we should expect to support.
I have found attention to be a bit frustrating, however. For example, running decompositions on an exported program with an sdpa op sometimes converts sdpa into a slightly different attention op, even when attention itself isn't being decomposed. Merely running decompositions at all actually retraces the graph with a different tool, and the ops it selects can further vary based on other factors like the torch device used by the inputs.
In any case, I don't mind adding a decomposition pattern. We have a bit more control with a pattern like this as opposed to fx decompositions.
@sahas3, based on the above comment, are you happy to keep the decomposition as is, or would you prefer to remove it?
Thanks for the reminder, I'm fine with adding the pattern.
The puzzle now is that if we add sdpa to the decomp list in fx_decomp_util, then although we have a LIT test locking down the decomposition pattern, it won't be exercised by the e2e test. However, the ops we are decomposing to with the C++ pattern are all locked down via e2e tests (I hope :D). So not being able to e2e test this C++ decomp pattern is probably fine? Thoughts @zjgarvey / @sjarus?
On that note, we probably need a way to control which ops get decomposed with PyTorch decompositions at the e2e test level, but that's out of scope for this PR.
sahas3 left a comment:
Thanks for the change. I'm fine with adding the decomposition pattern.
SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
                                     ShapedType::kDynamic);
if (rankedSelf.hasStaticShape()) {
  auto staticShape =
llvm::to_vector is redundant since the return type of makeShapeTorchCompatible is already a SmallVector.
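In other words (a small sketch; it relies on makeShapeTorchCompatible returning SmallVector<int64_t>, per the comment above):

```cpp
// The llvm::to_vector wrapper can be dropped; the helper's SmallVector
// return value binds directly.
SmallVector<int64_t> staticShape =
    makeShapeTorchCompatible(rankedSelf.getShape());
```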
      llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
  auto dim0Index = static_cast<size_t>(dim0);
  auto dim1Index = static_cast<size_t>(dim1);
  if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
Isn't this condition always guaranteed by this earlier check?

if (!isValidDim(dim0, selfRank) || !isValidDim(dim1, selfRank))
    for (size_t i = 0; i < staticShape.size(); ++i)
      transposedShape[i] = staticShape[i];
  }
  auto rankedResult = RankedTensorType::get(
IIUC, you are computing the transposed shape for statically shaped inputs and using that to construct tosa::TransposeOp. I think we can use tosa::CreateOpAndInfer with UnrankedTensorType::get(elemTy) and let the transpose op creation infer the result type instead of computing it here.
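A sketch of that suggestion, with the caveat that tosa::CreateOpAndInfer lives in TosaLegalizeUtils and that the exact form tosa.transpose expects for its permutation (const operand vs. dense-array attribute) depends on the TOSA dialect version in tree; selfTensor and perms are placeholder names:

```cpp
// Sketch: let TOSA shape inference produce the transposed result type
// instead of assembling it by hand. selfTensor is the already-converted
// tosa-typed input and perms is the dim0 <-> dim1 permutation (placeholders).
Type elemTy = rankedSelf.getElementType();
auto transposed = tosa::CreateOpAndInfer<tosa::TransposeOp>(
    rewriter, op->getLoc(),
    UnrankedTensorType::get(elemTy), // result type is inferred by the op
    selfTensor,
    rewriter.getDenseI32ArrayAttr(perms));
```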
target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
  auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
  auto resultType = dyn_cast<RankedTensorType>(op.getType());
  if (!sourceType || !resultType)
    return true;
  if (sourceType.getElementType() != resultType.getElementType())
    return true;
  if (!sourceType.hasStaticShape())
    return true;
  if (!resultType.hasStaticShape())
    return true;
  if (sourceType == resultType)
    return true;
  return false;
});
I am guessing this is needed because of the targetMaterialization change?
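For context, a target materialization of the following general shape is the kind of thing that introduces those tensor.cast ops in the first place (a generic sketch of the mechanism, not necessarily the exact hook in this PR):

```cpp
// Generic sketch: bridge type mismatches left by the conversion with a
// tensor.cast. Casts created this way are why tensor::CastOp then needs
// the dynamic-legality rule quoted below.
typeConverter.addTargetMaterialization(
    [](OpBuilder &builder, Type resultType, ValueRange inputs,
       Location loc) -> Value {
      if (inputs.size() != 1 || !isa<TensorType>(inputs.front().getType()) ||
          !isa<TensorType>(resultType))
        return Value();
      return builder.create<tensor::CastOp>(loc, resultType, inputs.front());
    });
```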
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
Suggested change:
- patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+ addPatternIfTargetOpIsIllegal<DecomposeAtenScaledDotProductAttentionOp>(patterns);
if (static_cast<int64_t>(keySizes.size()) != queryRank ||
    static_cast<int64_t>(valueSizes.size()) != queryRank)
  return rewriter.notifyMatchFailure(
      op, "expected query, key, and value to share rank");
This failure check should happen before any IR modification starts; otherwise the IR will be left in a bad state, since we've already introduced new ops but the original op is not replaced, eventually leading to the pass failing.
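Concretely, the idea is to hoist every notifyMatchFailure exit above the first create call, so the pattern either rewrites completely or leaves the IR untouched. An illustrative shape, reusing the variable names from the quoted diff hunks:

```cpp
// Illustrative structure only, reusing names from the quoted diff.
LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
                              PatternRewriter &rewriter) const override {
  // Phase 1: validation only; no ops have been created yet.
  if (static_cast<int64_t>(keySizes.size()) != queryRank ||
      static_cast<int64_t>(valueSizes.size()) != queryRank)
    return rewriter.notifyMatchFailure(
        op, "expected query, key, and value to share rank");
  if (keyTransposedSizes.size() < 2)
    return rewriter.notifyMatchFailure(
        op, "expected key tensor rank >= 2 for transpose");

  // Phase 2: only now start creating the transpose/matmul/softmax chain
  // and finish with rewriter.replaceOp(op, attentionOutput).
  // ...
  return success();
}
```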
if (keyTransposedSizes.size() < 2)
  return rewriter.notifyMatchFailure(
      op, "expected key tensor rank >= 2 for transpose");
This too has to be checked earlier.
if (!softmax)
  return rewriter.notifyMatchFailure(op,
                                     "failed to compute softmax scores");
This one is probably fine. Looking at getSoftmaxResult, I don't expect it to fail at this point. I can't think of a good way to guarantee that, though, short of replicating getSoftmaxResult's checks early here, which doesn't seem like a good idea either.
@@ -0,0 +1,29 @@
// RUN: torch-mlir-opt %s -torch-decompose-complex-ops -convert-torch-to-tosa -split-input-file | FileCheck %s

// Checks that scaled dot product attention (single-head configuration) lowers
I think a better place for this test is the existing file with this RUN line:

// RUN: torch-mlir-opt -torch-decompose-complex-ops -split-input-file %s | FileCheck %s