diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index e068259..e6914d5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -690,11 +690,6 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, Value b = gemmOpAdaptor.getB(); Value c = gemmOpAdaptor.getC(); - if (gemmOpAdaptor.getTransA()) { - gemmOp.emitOpError("requires transA=false before tiled Spatial Gemm lowering"); - return failure(); - } - auto aType = dyn_cast(a.getType()); auto bType = dyn_cast(b.getType()); auto outType = dyn_cast(gemmOp.getY().getType()); @@ -725,9 +720,12 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, return failure(); } - const int64_t numOutRows = outType.getDimSize(0); - const int64_t numOutCols = outType.getDimSize(1); - const int64_t reductionSize = aType.getDimSize(1); + if (gemmOpAdaptor.getTransA()) { + auto aShape = aType.getShape(); + auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding()); + a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult(); + aType = transposedType; + } if (gemmOpAdaptor.getTransB()) { auto bShape = bType.getShape(); @@ -736,6 +734,10 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, bType = transposedType; } + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutCols = outType.getDimSize(1); + const int64_t reductionSize = aType.getDimSize(1); + if (!isCompileTimeComputable(b)) { bool hasC = hasGemmBias(c); float alpha = gemmOpAdaptor.getAlpha().convertToFloat(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 922c05f..4b888b1 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -22,13 +22,87 @@ namespace { static FailureOr> inferSupportedBatchShape(ArrayRef lhsBatchShape, ArrayRef rhsBatchShape) { - if (lhsBatchShape.empty()) - return SmallVector(rhsBatchShape.begin(), rhsBatchShape.end()); - if (rhsBatchShape.empty()) - return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); - if (!llvm::equal(lhsBatchShape, rhsBatchShape)) - return failure(); - return SmallVector(lhsBatchShape.begin(), lhsBatchShape.end()); + const int64_t resultRank = std::max(lhsBatchShape.size(), rhsBatchShape.size()); + SmallVector resultShape(resultRank, 1); + for (int64_t resultIndex = resultRank - 1, lhsIndex = lhsBatchShape.size() - 1, rhsIndex = rhsBatchShape.size() - 1; + resultIndex >= 0; + --resultIndex, --lhsIndex, --rhsIndex) { + const int64_t lhsDim = lhsIndex >= 0 ? lhsBatchShape[lhsIndex] : 1; + const int64_t rhsDim = rhsIndex >= 0 ? rhsBatchShape[rhsIndex] : 1; + if (lhsDim != rhsDim && lhsDim != 1 && rhsDim != 1) + return failure(); + resultShape[resultIndex] = std::max(lhsDim, rhsDim); + } + return resultShape; +} + +static int64_t mapStaticBroadcastedBatchIndex(int64_t outputBatchIndex, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape) { + if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1) + return 0; + if (llvm::equal(sourceBatchShape, outputBatchShape)) + return outputBatchIndex; + + SmallVector outputStrides = computeRowMajorStrides(outputBatchShape); + SmallVector sourceStrides = computeRowMajorStrides(sourceBatchShape); + int64_t sourceFlatIndex = 0; + for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast(sourceBatchShape.size()); ++sourceDimIndex) { + if (sourceBatchShape[sourceDimIndex] == 1) + continue; + const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex; + const int64_t outputDimStride = outputStrides.empty() ? 1 : outputStrides[outputDimIndex]; + const int64_t outputDimIndexValue = outputDimStride == 1 + ? outputBatchIndex % outputBatchShape[outputDimIndex] + : (outputBatchIndex / outputDimStride) % outputBatchShape[outputDimIndex]; + sourceFlatIndex += outputDimIndexValue * sourceStrides[sourceDimIndex]; + } + return sourceFlatIndex; +} + +static Value computeFlatBatchIndexCoordinate( + Value flatBatchIndex, ArrayRef batchShape, int64_t dimIndex, PatternRewriter& rewriter, Location loc) { + if (batchShape[dimIndex] == 1) + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + + const int64_t dimStride = dimIndex + 1 == static_cast(batchShape.size()) + ? 1 + : getStaticShapeElementCount(batchShape.drop_front(dimIndex + 1)); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value dimCoordinate = flatBatchIndex; + if (dimStride != 1) + dimCoordinate = affineFloorDivConst(rewriter, loc, dimCoordinate, dimStride, anchorOp); + return affineModConst(rewriter, loc, dimCoordinate, batchShape[dimIndex], anchorOp); +} + +static Value mapOutputBatchIndexToSourceBatchIndex(Value outputBatchIndex, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + PatternRewriter& rewriter, + Location loc) { + if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1) + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + if (llvm::equal(sourceBatchShape, outputBatchShape)) + return outputBatchIndex; + + Value sourceBatchIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + SmallVector sourceStrides = computeRowMajorStrides(sourceBatchShape); + for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast(sourceBatchShape.size()); ++sourceDimIndex) { + if (sourceBatchShape[sourceDimIndex] == 1) + continue; + const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex; + Value outputCoordinate = + computeFlatBatchIndexCoordinate(outputBatchIndex, outputBatchShape, outputDimIndex, rewriter, loc); + Value contribution = sourceStrides[sourceDimIndex] == 1 + ? outputCoordinate + : affineMulConst(rewriter, + loc, + outputCoordinate, + sourceStrides[sourceDimIndex], + rewriter.getInsertionBlock()->getParentOp()); + sourceBatchIndex = arith::AddIOp::create(rewriter, loc, sourceBatchIndex, contribution); + } + return sourceBatchIndex; } static Value @@ -67,6 +141,52 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded); } +static Value createMatrixFromVector(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) { + auto buildExpanded = [&](Value input) -> Value { + return tensor::ExpandShapeOp::create(rewriter, + loc, + resultType, + input, + SmallVector { + {0, 1} + }); + }; + return materializeOrComputeUnary(value, resultType, rewriter, loc, buildExpanded); +} + +static SmallVector buildCollapseReassociation(ArrayRef removedAxes) { + SmallVector reassociation; + ReassociationIndices currentGroup; + for (auto [axis, removeAxis] : llvm::enumerate(removedAxes)) { + currentGroup.push_back(axis); + if (!removeAxis) { + reassociation.push_back(currentGroup); + currentGroup.clear(); + } + } + + if (!currentGroup.empty()) { + if (reassociation.empty()) + reassociation.push_back(std::move(currentGroup)); + else + reassociation.back().append(currentGroup.begin(), currentGroup.end()); + } + return reassociation; +} + +static Value squeezeUnitDims( + Value value, RankedTensorType resultType, ArrayRef removedAxes, PatternRewriter& rewriter, Location loc) { + if (cast(value.getType()) == resultType) + return value; + + SmallVector reassociation = + resultType.getRank() == 0 ? SmallVector {} : buildCollapseReassociation(removedAxes); + auto buildCollapsed = [&](Value input) -> Value { + return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation).getResult(); + }; + return materializeOrComputeUnary(value, resultType, rewriter, loc, buildCollapsed); +} + static Value ensureBatchedTensor( Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { auto type = cast(value.getType()); @@ -171,8 +291,11 @@ static Value createPaddedBatchedInputCompute(Value input, return computeOp.getResult(0); } -static FailureOr materializePaddedBatchedWeight( - Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) { +static FailureOr materializePaddedBatchedWeight(Value value, + ArrayRef sourceBatchShape, + ArrayRef targetBatchShape, + RankedTensorType resultType, + PatternRewriter& rewriter) { auto sourceType = cast(value.getType()); if (sourceType == resultType) return value; @@ -183,13 +306,15 @@ static FailureOr materializePaddedBatchedWeight( const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1); const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2); + const int64_t targetBatch = targetBatchShape.empty() ? 1 : getStaticShapeElementCount(targetBatchShape); const int64_t targetRows = resultType.getDimSize(1); const int64_t targetCols = resultType.getDimSize(2); SmallVector sourceValues(denseAttr.getValues()); SmallVector resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType())); for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) { - const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx); + const int64_t sourceBatchIdx = + sourceType.getRank() == 2 ? 0 : mapStaticBroadcastedBatchIndex(batchIdx, sourceBatchShape, targetBatchShape); const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols; const int64_t targetBatchBase = batchIdx * targetRows * targetCols; for (int64_t row = 0; row < sourceRows; ++row) @@ -202,16 +327,18 @@ static FailureOr materializePaddedBatchedWeight( } static Value extractBatchedATile(Value a, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value row, Value kOffset, RankedTensorType aTileType, PatternRewriter& rewriter, Location loc) { auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType()); - SmallVector offsets { - sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), row, kOffset}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))}; auto slice = @@ -227,8 +354,9 @@ static Value extractBatchedATile(Value a, } static Value extractBatchedBTile(Value b, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value kOffset, Value hOffset, RankedTensorType bTileType, @@ -236,8 +364,9 @@ static Value extractBatchedBTile(Value b, Location loc) { auto bSliceType = RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType()); - SmallVector offsets { - sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), kOffset, hOffset}; SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(bTileType.getDimSize(0)), rewriter.getIndexAttr(bTileType.getDimSize(1))}; @@ -262,9 +391,10 @@ static Value getBatchLaneIndex( static FailureOr createBatchedVmmBatch(Value a, Value b, RankedTensorType aType, - int64_t aBatchCount, + ArrayRef aBatchShape, RankedTensorType bType, - int64_t bBatchCount, + ArrayRef bBatchShape, + ArrayRef outputBatchShape, RankedTensorType partialPiecesType, int64_t numOutRows, int64_t numKSlices, @@ -298,10 +428,10 @@ static FailureOr createBatchedVmmBatch(Value a, auto pieceType = RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType()); - Value aTile = - extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc); - Value bTile = - extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc); + Value aTile = extractBatchedATile( + args.inputs.front(), aBatchShape, outputBatchShape, batch, row, kOffset, aTileType, rewriter, loc); + Value bTile = extractBatchedBTile( + args.weights.front(), bBatchShape, outputBatchShape, batch, kOffset, hOffset, bTileType, rewriter, loc); Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); SmallVector pieceOffsets {args.lane, rewriter.getIndexAttr(0)}; @@ -315,17 +445,17 @@ static FailureOr createBatchedVmmBatch(Value a, } static Value extractDynamicBatchedBColumn(Value matrix, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value column, RankedTensorType vectorType, PatternRewriter& rewriter, Location loc) { auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType()); - SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) - : OpFoldResult(batch), - rewriter.getIndexAttr(0), - column}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), rewriter.getIndexAttr(0), column}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)}; SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -350,17 +480,17 @@ static Value extractDynamicBatchedBColumn(Value matrix, } static Value extractDynamicBatchedRowVector(Value matrix, - int64_t sourceBatchCount, - Value batch, + ArrayRef sourceBatchShape, + ArrayRef outputBatchShape, + Value outputBatchIndex, Value row, RankedTensorType vectorType, PatternRewriter& rewriter, Location loc) { auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType()); - SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) - : OpFoldResult(batch), - row, - rewriter.getIndexAttr(0)}; + Value sourceBatchIndex = + mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc); + SmallVector offsets {OpFoldResult(sourceBatchIndex), row, rewriter.getIndexAttr(0)}; SmallVector sizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; auto rowSlice = @@ -376,9 +506,10 @@ static Value extractDynamicBatchedRowVector(Value matrix, } static FailureOr createBatchedVvdmulBatch(Value a, - int64_t aBatchCount, + ArrayRef aBatchShape, Value b, - int64_t bBatchCount, + ArrayRef bBatchShape, + ArrayRef outputBatchShape, RankedTensorType aType, RankedTensorType bType, RankedTensorType scalarPiecesType, @@ -406,10 +537,10 @@ static FailureOr createBatchedVvdmulBatch(Value a, auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); - Value aVector = - extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc); - Value bVector = - extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc); + Value aVector = extractDynamicBatchedRowVector( + args.inputs[0], aBatchShape, outputBatchShape, batch, row, vectorType, rewriter, loc); + Value bVector = extractDynamicBatchedBColumn( + args.inputs[1], bBatchShape, outputBatchShape, batch, column, vectorType, rewriter, loc); Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -629,11 +760,17 @@ static FailureOr createBatchedReductionCompute(Value partialPieces, return computeOp->getResult(0); } -struct MatMulShapeInfo { +struct NormalizedMatMulInfo { RankedTensorType lhsType; RankedTensorType rhsType; RankedTensorType outType; - SmallVector batchShape; + RankedTensorType normalizedLhsType; + RankedTensorType normalizedRhsType; + SmallVector lhsBatchShape; + SmallVector rhsBatchShape; + SmallVector outputBatchShape; + bool lhsWasVector; + bool rhsWasVector; int64_t lhsBatch; int64_t rhsBatch; int64_t batch; @@ -642,46 +779,170 @@ struct MatMulShapeInfo { int64_t n; }; -static FailureOr analyzeMatMulShape(ONNXMatMulOp matmulOp) { +struct MatMulLoweringPlan { + Value lhs; + Value rhs; + RankedTensorType lhsType; + RankedTensorType rhsType; + SmallVector lhsBatchShape; + SmallVector rhsBatchShape; + SmallVector outputBatchShape; + int64_t lhsBatch; + int64_t rhsBatch; + int64_t batch; + int64_t m; + int64_t k; + int64_t n; + bool transposedResult; +}; + +static SmallVector computeExpectedMatMulOutputShape( + ArrayRef batchShape, int64_t m, int64_t n, bool lhsWasVector, bool rhsWasVector) { + SmallVector shape(batchShape.begin(), batchShape.end()); + if (lhsWasVector && rhsWasVector) + return shape; + if (lhsWasVector) { + shape.push_back(n); + return shape; + } + if (rhsWasVector) { + shape.push_back(m); + return shape; + } + shape.push_back(m); + shape.push_back(n); + return shape; +} + +static FailureOr analyzeMatMulShape(ONNXMatMulOp matmulOp) { auto lhsType = dyn_cast(matmulOp.getA().getType()); auto rhsType = dyn_cast(matmulOp.getB().getType()); auto outType = dyn_cast(matmulOp.getY().getType()); if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() || !outType.hasStaticShape()) return failure(); - if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2) + if (lhsType.getRank() < 1 || rhsType.getRank() < 1) return failure(); if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType)) return failure(); - SmallVector lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2); - SmallVector rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2); - auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape); - if (failed(batchShape)) + const bool lhsWasVector = lhsType.getRank() == 1; + const bool rhsWasVector = rhsType.getRank() == 1; + auto normalizedLhsType = + lhsWasVector ? RankedTensorType::get({1, lhsType.getDimSize(0)}, lhsType.getElementType(), lhsType.getEncoding()) + : lhsType; + auto normalizedRhsType = + rhsWasVector ? RankedTensorType::get({rhsType.getDimSize(0), 1}, rhsType.getElementType(), rhsType.getEncoding()) + : rhsType; + + SmallVector lhsBatchShape(normalizedLhsType.getShape().begin(), normalizedLhsType.getShape().end() - 2); + SmallVector rhsBatchShape(normalizedRhsType.getShape().begin(), normalizedRhsType.getShape().end() - 2); + auto outputBatchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape); + if (failed(outputBatchShape)) return failure(); const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape); const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape); - const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape); - const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2); - const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1); - const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2); - const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1); + const int64_t batch = outputBatchShape->empty() ? 1 : getStaticShapeElementCount(*outputBatchShape); + const int64_t m = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 2); + const int64_t k = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 1); + const int64_t rhsK = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 2); + const int64_t n = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 1); if (k != rhsK) return failure(); - if (outType.getRank() == 2) { - if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n) - return failure(); - } - else { - SmallVector outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2); - if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m - || outType.getDimSize(outType.getRank() - 1) != n) - return failure(); + if (SmallVector(outType.getShape().begin(), outType.getShape().end()) + != computeExpectedMatMulOutputShape(*outputBatchShape, m, n, lhsWasVector, rhsWasVector)) { + return failure(); } - return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n}; + return NormalizedMatMulInfo {lhsType, + rhsType, + outType, + normalizedLhsType, + normalizedRhsType, + lhsBatchShape, + rhsBatchShape, + *outputBatchShape, + lhsWasVector, + rhsWasVector, + lhsBatch, + rhsBatch, + batch, + m, + k, + n}; +} + +static MatMulLoweringPlan buildLoweringPlan(Value normalizedLhs, + Value normalizedRhs, + const NormalizedMatMulInfo& info, + bool useTransposedForm, + PatternRewriter& rewriter, + Location loc) { + MatMulLoweringPlan plan {normalizedLhs, + normalizedRhs, + cast(normalizedLhs.getType()), + cast(normalizedRhs.getType()), + info.lhsBatchShape, + info.rhsBatchShape, + info.outputBatchShape, + info.lhsBatch, + info.rhsBatch, + info.batch, + info.m, + info.k, + info.n, + false}; + if (!useTransposedForm) + return plan; + + plan.lhs = transposeLastTwoDims(normalizedRhs, rewriter, loc); + plan.rhs = transposeLastTwoDims(normalizedLhs, rewriter, loc); + plan.lhsType = cast(plan.lhs.getType()); + plan.rhsType = cast(plan.rhs.getType()); + std::swap(plan.lhsBatchShape, plan.rhsBatchShape); + std::swap(plan.lhsBatch, plan.rhsBatch); + plan.m = info.n; + plan.n = info.m; + plan.transposedResult = true; + return plan; +} + +static Value normalizeMatMulOperand( + Value value, RankedTensorType normalizedType, bool wasVector, PatternRewriter& rewriter, Location loc) { + if (!wasVector) + return value; + return createMatrixFromVector(value, normalizedType, rewriter, loc); +} + +static Value finalizeNormalizedMatMulResult(Value value, + RankedTensorType directOutType, + const NormalizedMatMulInfo& info, + PatternRewriter& rewriter, + Location loc) { + // The direct lowered result is always [flatBatch, normalizedM, normalizedN]. + // Restore ONNX MatMul result rank by expanding right-aligned batch dimensions + // and removing the synthetic unit matrix axes introduced for vector operands. + Value result = value; + RankedTensorType currentType = directOutType; + if (info.outputBatchShape.size() > 1) { + SmallVector expandedShape(info.outputBatchShape.begin(), info.outputBatchShape.end()); + expandedShape.push_back(info.m); + expandedShape.push_back(info.n); + auto expandedType = RankedTensorType::get(expandedShape, info.outType.getElementType(), info.outType.getEncoding()); + result = expandBatchDims(result, expandedType, info.outputBatchShape.size(), rewriter, loc); + currentType = expandedType; + } + + SmallVector removedAxes(currentType.getRank(), false); + if (info.outputBatchShape.empty()) + removedAxes[0] = true; + if (info.lhsWasVector) + removedAxes[currentType.getRank() - 2] = true; + if (info.rhsWasVector) + removedAxes[currentType.getRank() - 1] = true; + return squeezeUnitDims(result, info.outType, removedAxes, rewriter, loc); } struct MatMulToGemm : OpRewritePattern { @@ -689,7 +950,7 @@ struct MatMulToGemm : OpRewritePattern { LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { auto shapeInfo = analyzeMatMulShape(matmulOp); - if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2) + if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty()) return failure(); Location loc = matmulOp.getLoc(); @@ -742,61 +1003,56 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { auto shapeInfo = analyzeMatMulShape(matmulOp); if (failed(shapeInfo)) return failure(); - if (shapeInfo->outType.getRank() == 2) + if (!shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector && shapeInfo->outputBatchShape.empty()) return failure(); Location loc = matmulOp.getLoc(); - bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); + bool useTransposedForm = !shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector + && isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); - Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc); - Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc); - int64_t lhsBatchForGemm = shapeInfo->lhsBatch; - int64_t rhsBatchForGemm = shapeInfo->rhsBatch; - int64_t gemmM = shapeInfo->m; - int64_t gemmK = shapeInfo->k; - int64_t gemmN = shapeInfo->n; - if (useTransposedForm) { - lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc); - lhsBatchForGemm = shapeInfo->rhsBatch; - rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc); - rhsBatchForGemm = shapeInfo->lhsBatch; - gemmM = shapeInfo->n; - gemmN = shapeInfo->m; - } + Value lhs = + normalizeMatMulOperand(matmulOp.getA(), shapeInfo->normalizedLhsType, shapeInfo->lhsWasVector, rewriter, loc); + Value rhs = + normalizeMatMulOperand(matmulOp.getB(), shapeInfo->normalizedRhsType, shapeInfo->rhsWasVector, rewriter, loc); + lhs = collapseBatchDims(lhs, shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc); + rhs = collapseBatchDims(rhs, shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc); + MatMulLoweringPlan plan = buildLoweringPlan(lhs, rhs, *shapeInfo, useTransposedForm, rewriter, loc); - lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); - rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); - auto lhsBatchedType = cast(lhs.getType()); - auto rhsBatchedType = cast(rhs.getType()); - auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType()); + plan.lhs = ensureBatchedTensor(plan.lhs, plan.lhsBatch, plan.m, plan.k, rewriter, loc); + plan.rhs = ensureBatchedTensor(plan.rhs, plan.rhsBatch, plan.k, plan.n, rewriter, loc); + plan.lhsType = cast(plan.lhs.getType()); + plan.rhsType = cast(plan.rhs.getType()); + auto directOutType = RankedTensorType::get( + {plan.batch, plan.m, plan.n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); - if (isCompileTimeComputable(rhs)) { - const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue()); - const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue()); + if (isCompileTimeComputable(plan.rhs)) { + const int64_t numKSlices = ceilIntegerDivide(plan.k, crossbarSize.getValue()); + const int64_t numOutHSlices = ceilIntegerDivide(plan.n, crossbarSize.getValue()); const int64_t paddedReductionSize = numKSlices * static_cast(crossbarSize.getValue()); const int64_t paddedOutCols = numOutHSlices * static_cast(crossbarSize.getValue()); auto paddedLhsType = RankedTensorType::get( - {lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding()); - auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols}, - rhsBatchedType.getElementType(), - rhsBatchedType.getEncoding()); + {plan.lhsBatch, plan.m, paddedReductionSize}, plan.lhsType.getElementType(), plan.lhsType.getEncoding()); + auto paddedRhsType = RankedTensorType::get( + {plan.batch, paddedReductionSize, paddedOutCols}, plan.rhsType.getElementType(), plan.rhsType.getEncoding()); auto paddedOutType = - RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType()); + RankedTensorType::get({plan.batch, plan.m, paddedOutCols}, shapeInfo->outType.getElementType()); - auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter); + auto paddedRhs = + materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter); if (succeeded(paddedRhs)) { - Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc); - const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices; + Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc); + const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices; auto partialPiecesType = RankedTensorType::get({laneCount, static_cast(crossbarSize.getValue())}, shapeInfo->outType.getElementType()); auto batchOp = createBatchedVmmBatch(paddedLhs, *paddedRhs, paddedLhsType, - lhsBatchForGemm, + plan.lhsBatchShape, paddedRhsType, - rhsBatchForGemm, + plan.rhsBatchShape, + plan.outputBatchShape, partialPiecesType, - gemmM, + plan.m, numKSlices, numOutHSlices, rewriter, @@ -807,34 +1063,35 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { partialPiecesType, directOutType, paddedOutType, - shapeInfo->batch, + plan.batch, numKSlices, rewriter, loc); if (failed(result)) return failure(); Value finalResult = *result; - if (useTransposedForm) { - auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, + if (plan.transposedResult) { + auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); finalResult = ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1})) .getResult(); } - finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); + finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc); rewriter.replaceOp(matmulOp, finalResult); return success(); } } - const int64_t laneCount = shapeInfo->batch * gemmM * gemmN; + const int64_t laneCount = plan.batch * plan.m * plan.n; auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType()); - auto batchOp = createBatchedVvdmulBatch(lhs, - lhsBatchForGemm, - rhs, - rhsBatchForGemm, - lhsBatchedType, - rhsBatchedType, + auto batchOp = createBatchedVvdmulBatch(plan.lhs, + plan.lhsBatchShape, + plan.rhs, + plan.rhsBatchShape, + plan.outputBatchShape, + plan.lhsType, + plan.rhsType, scalarPiecesType, directOutType, rewriter, @@ -846,15 +1103,15 @@ struct MatMulBatchedToSpatialComputes : OpRewritePattern { if (failed(result)) return failure(); Value finalResult = *result; - if (useTransposedForm) { - auto transposedOutType = RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, + if (plan.transposedResult) { + auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); finalResult = ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1})) .getResult(); } - finalResult = expandBatchDims(finalResult, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); + finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc); rewriter.replaceOp(matmulOp, finalResult); return success(); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index bf4ddab..2f03115 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -238,14 +238,8 @@ static Value squeezeReducedAxes(Value keepdimsValue, ArrayRef reducedAxes, ConversionPatternRewriter& rewriter, Location loc) { - if (resultType.getRank() == 0) { - SmallVector indices(cast(keepdimsValue.getType()).getRank(), - getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0)); - Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices); - return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element}); - } - - auto reassociation = buildCollapseReassociation(reducedAxes); + SmallVector reassociation = + resultType.getRank() == 0 ? SmallVector {} : buildCollapseReassociation(reducedAxes); if (isCompileTimeComputable(keepdimsValue)) return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult(); diff --git a/validation/operations/add/scalar_runtime/add_scalar_runtime.onnx b/validation/operations/add/scalar_runtime/add_scalar_runtime.onnx deleted file mode 100644 index 4d138c1..0000000 Binary files a/validation/operations/add/scalar_runtime/add_scalar_runtime.onnx and /dev/null differ diff --git a/validation/operations/div/runtime_scalar_lhs/div_runtime_scalar_lhs.onnx b/validation/operations/div/runtime_scalar_lhs/div_runtime_scalar_lhs.onnx deleted file mode 100644 index 3401328..0000000 Binary files a/validation/operations/div/runtime_scalar_lhs/div_runtime_scalar_lhs.onnx and /dev/null differ diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index a1fdf18..ba09a7a 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -779,28 +779,6 @@ def matmul_matrix_vector(): save_model(model, "matmul/matrix_vector", "matmul_matrix_vector.onnx") -def matmul_vector_vector_dot(): - """Vector-vector MatMul producing a scalar output.""" - A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1024]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, []) - B = numpy_helper.from_array(np.random.default_rng(97).uniform(-1, 1, (1024,)).astype(np.float32), name="B") - node = helper.make_node("MatMul", ["A", "B"], ["Y"]) - graph = helper.make_graph([node], "matmul_vector_vector_dot", [A], [Y], initializer=[B]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "matmul/vector_vector_dot", "matmul_vector_vector_dot.onnx") - - -def matmul_batched_4d_broadcast(): - """Batched 4D MatMul with broadcast across leading dimensions.""" - A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 1, 3, 4]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 5, 3, 6]) - B = numpy_helper.from_array(np.random.default_rng(98).uniform(-1, 1, (1, 5, 4, 6)).astype(np.float32), name="B") - node = helper.make_node("MatMul", ["A", "B"], ["Y"]) - graph = helper.make_graph([node], "matmul_batched_4d_broadcast", [A], [Y], initializer=[B]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "matmul/batched_4d_broadcast", "matmul_batched_4d_broadcast.onnx") - - # --------------------------------------------------------------------------- # Pooling tests # --------------------------------------------------------------------------- @@ -1560,17 +1538,6 @@ def add_channel_broadcast_1024(): save_model(model, "add/channel_broadcast_1024", "add_channel_broadcast_1024.onnx") -def add_scalar_runtime(): - """Elementwise Add with a runtime scalar RHS.""" - A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1]) - B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1]) - node = helper.make_node("Add", ["A", "B"], ["Y"]) - graph = helper.make_graph([node], "add_scalar_runtime", [A, B], [Y]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "add/scalar_runtime", "add_scalar_runtime.onnx") - - def add_leading_dimension_broadcast(): """Elementwise Add with trailing-dimension broadcasting.""" A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4]) @@ -1635,17 +1602,6 @@ def mul_channel_broadcast_1024(): save_model(model, "mul/channel_broadcast_1024", "mul_channel_broadcast_1024.onnx") -def mul_scalar_runtime(): - """Elementwise Mul with a runtime scalar RHS.""" - A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1024, 1, 1]) - B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 1, 1]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1]) - node = helper.make_node("Mul", ["A", "B"], ["Y"]) - graph = helper.make_graph([node], "mul_scalar_runtime", [A, B], [Y]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "mul/scalar_runtime", "mul_scalar_runtime.onnx") - - def mul_leading_dimension_broadcast(): """Elementwise Mul with trailing-dimension broadcasting.""" A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4]) @@ -1721,17 +1677,6 @@ def div_runtime_scalar_rhs(): save_model(model, "div/runtime_scalar_rhs", "div_runtime_scalar_rhs.onnx") -def div_runtime_scalar_lhs(): - """Elementwise Div with a scalar constant numerator.""" - B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1024, 1, 1]) - Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1024, 1, 1]) - A = numpy_helper.from_array(np.asarray([[[[2.0]]]], dtype=np.float32), name="A") - node = helper.make_node("Div", ["A", "B"], ["Y"]) - graph = helper.make_graph([node], "div_runtime_scalar_lhs", [B], [Y], initializer=[A]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - save_model(model, "div/runtime_scalar_lhs", "div_runtime_scalar_lhs.onnx") - - def div_leading_dimension_broadcast(): """Elementwise Div with trailing-dimension broadcasting.""" A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3, 4]) @@ -1812,8 +1757,6 @@ if __name__ == "__main__": matmul_huge_1024() matmul_vector_matrix() matmul_matrix_vector() - matmul_vector_vector_dot() - matmul_batched_4d_broadcast() print("\nGenerating Pooling tests:") maxpool_basic() @@ -1899,7 +1842,6 @@ if __name__ == "__main__": add_broadcast_row() add_after_gemm() add_channel_broadcast_1024() - add_scalar_runtime() add_leading_dimension_broadcast() print("\nGenerating Mul tests:") @@ -1907,7 +1849,6 @@ if __name__ == "__main__": mul_scalar_constant() mul_after_conv() mul_channel_broadcast_1024() - mul_scalar_runtime() mul_leading_dimension_broadcast() print("\nGenerating Div tests:") @@ -1916,7 +1857,6 @@ if __name__ == "__main__": div_after_gemm() div_channel_broadcast_1024() div_runtime_scalar_rhs() - div_runtime_scalar_lhs() div_leading_dimension_broadcast() print("\nDone.") diff --git a/validation/operations/matmul/batched_4d_broadcast/matmul_batched_4d_broadcast.onnx b/validation/operations/matmul/batched_4d_broadcast/matmul_batched_4d_broadcast.onnx deleted file mode 100644 index 25382b8..0000000 Binary files a/validation/operations/matmul/batched_4d_broadcast/matmul_batched_4d_broadcast.onnx and /dev/null differ diff --git a/validation/operations/matmul/vector_vector_dot/matmul_vector_vector_dot.onnx b/validation/operations/matmul/vector_vector_dot/matmul_vector_vector_dot.onnx deleted file mode 100644 index cf96880..0000000 Binary files a/validation/operations/matmul/vector_vector_dot/matmul_vector_vector_dot.onnx and /dev/null differ diff --git a/validation/operations/mul/scalar_runtime/mul_scalar_runtime.onnx b/validation/operations/mul/scalar_runtime/mul_scalar_runtime.onnx deleted file mode 100644 index 55e63e0..0000000 Binary files a/validation/operations/mul/scalar_runtime/mul_scalar_runtime.onnx and /dev/null differ