Commit 9ef3ed1

[TOSA] MultiheadAttention legalization

- Legalize Torch scaled_dot_product_attention into TOSA by adding the necessary patterns in TorchToTosa.cpp plus backend type-conversion hooks.
- Introduce a detailed decomposition path for multi-head attention within DecomposeComplexOps.cpp, preparing inputs for TOSA lowering.
- Expand the PT1 e2e suite with a dedicated multi-head attention MLIR/Python test and drop the corresponding xfails now that the path works.

Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: I96c17aefd25b979f1cf6e897d91d5a29f0a2fa85
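For orientation, a minimal PyTorch sketch of the pipeline the new decomposition targets (no attn_mask, dropout_p == 0.0, non-causal, no GQA); the helper name sdpa_reference is illustrative and not part of this patch:

import math
import torch

def sdpa_reference(query, key, value, scale=None):
    # Default scale is 1/sqrt(head_dim), matching the decomposition's
    # fallback when no explicit scale is supplied.
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    # scores = Q @ K^T * scale, softmax over the key dimension, then @ V.
    scores = torch.matmul(query, key.transpose(-2, -1)) * scale
    attn = torch.softmax(scores, dim=-1)
    return torch.matmul(attn, value)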
1 parent 244f4b6 commit 9ef3ed1

File tree: 6 files changed, +306 / -20 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 36 additions & 1 deletion
@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
   transposedDims[dim0] = dim1;
   transposedDims[dim1] = dim0;

+  Type resultType = getTypeConverter()->convertType(op.getType());
+  if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
+    SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
+                                         ShapedType::kDynamic);
+    if (rankedSelf.hasStaticShape()) {
+      auto staticShape =
+          llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
+      auto dim0Index = static_cast<size_t>(dim0);
+      auto dim1Index = static_cast<size_t>(dim1);
+      if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
+        std::swap(staticShape[dim0Index], staticShape[dim1Index]);
+      for (size_t i = 0; i < staticShape.size(); ++i)
+        transposedShape[i] = staticShape[i];
+    }
+    auto rankedResult = RankedTensorType::get(
+        makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
+    if (auto converted = getTypeConverter()->convertType(rankedResult))
+      resultType = converted;
+  }
+
   rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
-      op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
+      op, resultType, adaptor.getSelf(),
       rewriter.getDenseI32ArrayAttr(transposedDims));

   return success();

@@ -9402,6 +9422,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
   target.addLegalOp<ConstantDeviceOp>();
   target.addLegalOp<PrimListConstructOp>();
   target.addLegalOp<PrimTupleConstructOp>();
+  target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
+    auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
+    auto resultType = dyn_cast<RankedTensorType>(op.getType());
+    if (!sourceType || !resultType)
+      return true;
+    if (sourceType.getElementType() != resultType.getElementType())
+      return true;
+    if (!sourceType.hasStaticShape())
+      return true;
+    if (!resultType.hasStaticShape())
+      return true;
+    if (sourceType == resultType)
+      return true;
+    return false;
+  });
 }

 std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
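For illustration, a rough Python sketch of the new tensor.cast legality rule (the RankedTensor helper type is invented for this example, not part of the patch): a cast stays legal, i.e. is left untouched, unless both sides are ranked, fully static, share an element type, and differ in shape; only those casts are forced through the conversion.

from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass(frozen=True)
class RankedTensor:
    shape: Tuple[Optional[int], ...]  # None marks a dynamic dimension
    element_type: str

def cast_is_legal(source: RankedTensor, result: RankedTensor) -> bool:
    # Mirrors the addDynamicallyLegalOp<tensor::CastOp> predicate above.
    if source.element_type != result.element_type:
        return True
    if any(d is None for d in source.shape + result.shape):
        return True
    return source == result

# Example: a fully static f32 cast that changes the shape is reported illegal.
assert not cast_is_legal(RankedTensor((2, 3), "f32"), RankedTensor((3, 2), "f32"))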

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 219 additions & 9 deletions
@@ -2295,17 +2295,223 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace

+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter);
+
+namespace {
+// Decompose scaled dot product attention into matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(op,
+                                         "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
+    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
+    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
+    if (!queryTensorType || !keyTensorType || !valueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
+    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
+        !valueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
+    if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
+        !valueValueTensorType.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have dtypes");
+    Type queryDtype = queryValueTensorType.getOptionalDtype();
+    if (queryDtype != keyValueTensorType.getOptionalDtype() ||
+        queryDtype != valueValueTensorType.getOptionalDtype())
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share dtype");
+
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    int64_t queryRank = querySizes.size();
+    if (queryRank < 3 || queryRank > 4)
+      return rewriter.notifyMatchFailure(
+          op, "expected query tensor rank to be 3 or 4");
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
+    if (static_cast<int64_t>(keySizes.size()) != queryRank ||
+        static_cast<int64_t>(valueSizes.size()) != queryRank)
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share rank");
+    bool hasExplicitHeadDim = queryRank == 4;
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat =
+          ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    if (keyTransposedSizes.size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected key tensor rank >= 2 for transpose");
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
+    std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
+    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+        keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (hasExplicitHeadDim) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
+    Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
+                                     scoresType, scoresType, rewriter);
+    if (!softmax)
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to compute softmax scores");
+
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
 // x_max = max(input, dim, keepdim = True)
 // unnorm = aten.exp(input - x_max)
 // softmax = unnorm / sum(unnorm, dim, keepdim = True)
-template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
-                              Type accumulatorType, PatternRewriter &rewriter) {
-  Location loc = op.getLoc();
-  Value dim = op.getDim();
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
   if (resultType != accumulatorType)
     self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
   Value xMax =

@@ -2362,8 +2568,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {

     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),

@@ -2411,8 +2618,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {

     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return op.emitError("failed to get softmax result");
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,

@@ -13084,6 +13292,8 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());

+    patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
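The reworked getSoftmaxResult helper follows the overflow-safe rule quoted in the comment above; a minimal PyTorch sketch of that rule (illustrative only, not code from this patch):

import torch

def stable_softmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # x_max   = max(input, dim, keepdim=True)
    # unnorm  = exp(input - x_max)
    # softmax = unnorm / sum(unnorm, dim, keepdim=True)
    x_max, _ = torch.max(x, dim=dim, keepdim=True)
    unnorm = torch.exp(x - x_max)
    return unnorm / torch.sum(unnorm, dim=dim, keepdim=True)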

lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp

Lines changed: 20 additions & 0 deletions
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//

 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

 using namespace mlir;

@@ -40,6 +41,25 @@ static void setupValueTensorToBuiltinTensorConversion(
       return {};
     return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
   });
+  typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
+                                            ValueRange inputs,
+                                            Location loc) -> Value {
+    if (inputs.size() != 1)
+      return Value();
+    auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
+    auto toType = dyn_cast<RankedTensorType>(type);
+    if (!fromType || !toType)
+      return Value();
+    if (fromType == toType)
+      return inputs[0];
+    if (fromType.getElementType() != toType.getElementType())
+      return Value();
+    if (!toType.hasStaticShape())
+      return Value();
+    if (!tensor::CastOp::areCastCompatible(inputs[0].getType(), toType))
+      return Value();
+    return tensor::CastOp::create(builder, loc, toType, inputs[0]);
+  });
   auto sourceMaterialization = [](OpBuilder &builder,
                                   Torch::ValueTensorType type,
                                   ValueRange inputs, Location loc) -> Value {

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 0 additions & 10 deletions
@@ -50,11 +50,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
 }

 LINALG_CRASHING_SET = {

@@ -953,11 +950,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "SubIntModule_basic",
     "TensorToIntZeroRank_basic",
     "UpSampleNearest2dDynamicFactor_basic",

@@ -3978,11 +3972,8 @@
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionDifferentCausalModule_basic",
     "ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
-    "ScaledDotProductAttentionDifferentModule_basic",
     "ScaledDotProductAttentionMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
-    "ScaledDotProductAttentionSameModule_basic",
     "ScaledDotProductAttentionGQAModule_basic",
     # error: 'tosa.scatter' op requires dimensions K >= W
     "IndexPut1DFloatNonAccumulateModule_basic",

@@ -4887,7 +4878,6 @@
     # REMOVE WHEN ENABLE_GQA IS ADDED
     "ScaledDotProductAttentionBoolMaskModule_basic",
     "ScaledDotProductAttentionSameCausalModule_basic",
-    "ScaledDotProductAttentionSameDynamicModule_basic",
     "ScatterAddDynamicModule_basic",
     "ScatterReduceFloatMaxModule",
     "ScatterReduceFloatMaxModuleIncludeSelf",