@@ -2295,17 +2295,223 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
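+// Forward declaration: the scaled dot product attention decomposition below
+// reuses the shared softmax helper that is defined later in this file.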
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter);
+
+namespace {
+// Decompose scaled dot product attention into a matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
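+// The decomposition computes output = softmax(Q @ K^T * scale, dim=-1) @ V.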
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(op,
+                                         "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
+    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
+    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
+    if (!queryTensorType || !keyTensorType || !valueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
+    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
+        !valueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
+    if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
+        !valueValueTensorType.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have dtypes");
+    Type queryDtype = queryValueTensorType.getOptionalDtype();
+    if (queryDtype != keyValueTensorType.getOptionalDtype() ||
+        queryDtype != valueValueTensorType.getOptionalDtype())
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share dtype");
+
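+    // The query tensor determines the dynamic dimension sizes used below:
+    // the head dim is the last dim, the sequence length is the
+    // second-to-last dim, and the batch size is dim 0.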
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    int64_t queryRank = querySizes.size();
+    if (queryRank < 3 || queryRank > 4)
+      return rewriter.notifyMatchFailure(
+          op, "expected query tensor rank to be 3 or 4");
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
+    if (static_cast<int64_t>(keySizes.size()) != queryRank ||
+        static_cast<int64_t>(valueSizes.size()) != queryRank)
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share rank");
+    bool hasExplicitHeadDim = queryRank == 4;
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
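+    // PyTorch's default scale is 1/sqrt(head_dim); only compute it when the
+    // optional `scale` operand is absent.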
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat =
+          ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
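+    // Swap the last two dimensions of the key so that Q @ K^T can be emitted
+    // as a single matmul.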
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    if (keyTransposedSizes.size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected key tensor rank >= 2 for transpose");
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+        keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (hasExplicitHeadDim) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
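+    // Attention scores: scores = Q @ K^T, with a result type whose static
+    // sizes are taken from the query/key shapes where they are known.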
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
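+    // Normalize along the last (key sequence) dimension, reusing the shared
+    // softmax helper so the numerically stable exp/sum decomposition is
+    // emitted here as well.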
+    Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
+                                     scoresType, scoresType, rewriter);
+    if (!softmax)
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to compute softmax scores");
+
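+    // Final projection onto the values: output = softmax(Q @ K^T * scale) @ V.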
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
 //     x_max = max(input, dim, keepdim = True)
 //     unnorm = aten.exp(input - x_max)
 //     softmax = unnorm / sum(unnorm, dim, keepdim = True)
-template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
-                              Type accumulatorType, PatternRewriter &rewriter) {
-  Location loc = op.getLoc();
-  Value dim = op.getDim();
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
   if (resultType != accumulatorType)
     self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
   Value xMax =
@@ -2362,8 +2568,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
 
     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
 
-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -2411,8 +2618,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
 
     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
 
-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return op.emitError("failed to get softmax result");
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
@@ -13084,6 +13292,8 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);