#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

// NOTE(review): template argument lists look like they were stripped from
// this text (e.g. `FailureOr>`, `ArrayRef lhsBatchShape`, `static_cast(...)`).
// Confirm against the original file before building.

/// Computes the broadcasted batch shape of two operand batch shapes.
///
/// The two shapes are aligned on their trailing dimension; a dimension that is
/// missing on the shorter shape is treated as 1. Two aligned extents are
/// compatible when they are equal or when one of them is 1, and the result
/// extent is the larger of the two.
///
/// \returns the broadcasted shape, or failure() when an aligned pair of
///          extents differs and neither extent is 1.
static FailureOr> inferSupportedBatchShape(ArrayRef lhsBatchShape, ArrayRef rhsBatchShape) {
  const int64_t resultRank = std::max(lhsBatchShape.size(), rhsBatchShape.size());
  SmallVector resultShape(resultRank, 1);
  // Walk all three shapes from the last dimension to the first. lhsIndex and
  // rhsIndex go negative once the shorter shape has been exhausted.
  for (int64_t resultIndex = resultRank - 1, lhsIndex = lhsBatchShape.size() - 1,
               rhsIndex = rhsBatchShape.size() - 1;
       resultIndex >= 0;
       --resultIndex, --lhsIndex, --rhsIndex) {
    const int64_t lhsDim = lhsIndex >= 0 ? lhsBatchShape[lhsIndex] : 1;
    const int64_t rhsDim = rhsIndex >= 0 ?
rhsBatchShape[rhsIndex] : 1;
    // Extents must match unless one side is 1 (broadcastable).
    if (lhsDim != rhsDim && lhsDim != 1 && rhsDim != 1)
      return failure();
    resultShape[resultIndex] = std::max(lhsDim, rhsDim);
  }
  return resultShape;
}

/// Maps a flat index into \p outputBatchShape to the flat index of the
/// element of \p sourceBatchShape that is broadcast onto it. Everything is
/// computed on compile-time integers; no IR is created.
///
/// - Returns 0 when the source has no batch dimensions or a single element.
/// - Returns \p outputBatchIndex unchanged when both shapes are equal.
/// - Otherwise, for every source dimension whose extent is not 1, recovers the
///   coordinate along the right-aligned output dimension and accumulates
///   coordinate * sourceStride.
///
/// Strides come from computeRowMajorStrides (defined elsewhere).
static int64_t mapStaticBroadcastedBatchIndex(int64_t outputBatchIndex,
    ArrayRef sourceBatchShape,
    ArrayRef outputBatchShape) {
  if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
    return 0;
  if (llvm::equal(sourceBatchShape, outputBatchShape))
    return outputBatchIndex;
  SmallVector outputStrides = computeRowMajorStrides(outputBatchShape);
  SmallVector sourceStrides = computeRowMajorStrides(sourceBatchShape);
  int64_t sourceFlatIndex = 0;
  for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast(sourceBatchShape.size()); ++sourceDimIndex) {
    // Unit source dimensions are broadcast and contribute nothing.
    if (sourceBatchShape[sourceDimIndex] == 1)
      continue;
    // Shapes are right-aligned: shift by the rank difference.
    const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
    const int64_t outputDimStride = outputStrides.empty() ? 1 : outputStrides[outputDimIndex];
    // coordinate = (flat / stride) % extent; the division is skipped when the
    // stride is 1.
    const int64_t outputDimIndexValue = outputDimStride == 1 ? outputBatchIndex % outputBatchShape[outputDimIndex]
                                                             : (outputBatchIndex / outputDimStride)
                                                                   % outputBatchShape[outputDimIndex];
    sourceFlatIndex += outputDimIndexValue * sourceStrides[sourceDimIndex];
  }
  return sourceFlatIndex;
}

/// Emits IR that extracts the coordinate along dimension \p dimIndex of
/// \p batchShape from the flat index value \p flatBatchIndex.
///
/// For a unit dimension the result is the index constant 0. Otherwise the
/// result is (flat floordiv stride) mod extent, where stride is the element
/// count of the dimensions after \p dimIndex (1 for the last dimension) and
/// the division is omitted when the stride is 1.
static Value computeFlatBatchIndexCoordinate(
    Value flatBatchIndex, ArrayRef batchShape, int64_t dimIndex, PatternRewriter& rewriter, Location loc) {
  if (batchShape[dimIndex] == 1)
    return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
  const int64_t dimStride = dimIndex + 1 == static_cast(batchShape.size()) ?
1 : getStaticShapeElementCount(batchShape.drop_front(dimIndex + 1));
  // The affine helpers receive the op that owns the current insertion block.
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  Value dimCoordinate = flatBatchIndex;
  if (dimStride != 1)
    dimCoordinate = affineFloorDivConst(rewriter, loc, dimCoordinate, dimStride, anchorOp);
  return affineModConst(rewriter, loc, dimCoordinate, batchShape[dimIndex], anchorOp);
}

/// Emits IR that maps a flat index into \p outputBatchShape to the flat index
/// into \p sourceBatchShape it is broadcast from. This follows the same case
/// split as mapStaticBroadcastedBatchIndex, but on SSA values.
///
/// - Source with no batch dims or a single element: index constant 0.
/// - Equal shapes: \p outputBatchIndex is returned as is.
/// - Otherwise: sum over non-unit source dims of
///   outputCoordinate(dim) * sourceStride(dim), accumulated with arith.addi
///   starting from the index constant 0.
static Value mapOutputBatchIndexToSourceBatchIndex(Value outputBatchIndex,
    ArrayRef sourceBatchShape,
    ArrayRef outputBatchShape,
    PatternRewriter& rewriter,
    Location loc) {
  if (sourceBatchShape.empty() || getStaticShapeElementCount(sourceBatchShape) == 1)
    return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
  if (llvm::equal(sourceBatchShape, outputBatchShape))
    return outputBatchIndex;
  Value sourceBatchIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
  SmallVector sourceStrides = computeRowMajorStrides(sourceBatchShape);
  for (int64_t sourceDimIndex = 0; sourceDimIndex < static_cast(sourceBatchShape.size()); ++sourceDimIndex) {
    // Broadcast (unit) source dimensions do not contribute to the index.
    if (sourceBatchShape[sourceDimIndex] == 1)
      continue;
    // Shapes are right-aligned: shift by the rank difference.
    const int64_t outputDimIndex = outputBatchShape.size() - sourceBatchShape.size() + sourceDimIndex;
    Value outputCoordinate =
        computeFlatBatchIndexCoordinate(outputBatchIndex, outputBatchShape, outputDimIndex, rewriter, loc);
    // The multiplication is skipped when the source stride is 1.
    Value contribution = sourceStrides[sourceDimIndex] == 1 ?
outputCoordinate : affineMulConst(rewriter, loc, outputCoordinate, sourceStrides[sourceDimIndex], rewriter.getInsertionBlock()->getParentOp());
    sourceBatchIndex = arith::AddIOp::create(rewriter, loc, sourceBatchIndex, contribution);
  }
  return sourceBatchIndex;
}

/// Collapses all leading (batch) dimensions of \p value into one, producing a
/// tensor of shape {batchSize, rows, cols} with the element type and encoding
/// of the input.
///
/// Values of rank 2 or 3 are returned unchanged. The collapse_shape op is
/// created through materializeOrComputeUnary (defined elsewhere), which is
/// given the value, the result type and the builder callback.
static Value collapseBatchDims(Value value,
    int64_t batchSize,
    int64_t rows,
    int64_t cols,
    PatternRewriter& rewriter,
    Location loc) {
  auto type = cast(value.getType());
  if (type.getRank() == 2 || type.getRank() == 3)
    return value;
  auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
  // Three groups: {all leading dims}, {rank - 2}, {rank - 1}. The first group
  // starts empty and is filled by the loop below.
  SmallVector reassociation = {ReassociationIndices {},
                               ReassociationIndices {static_cast(type.getRank() - 2)},
                               ReassociationIndices {static_cast(type.getRank() - 1)}};
  for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
    reassociation.front().push_back(dim);
  auto buildCollapsed = [&](Value input) -> Value {
    return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
  };
  return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
}

/// Expands the leading dimension of \p value into \p batchRank batch
/// dimensions so that the result has type \p outputType.
///
/// Returns \p value unchanged when it already has \p outputType. The
/// reassociation maps source dim 0 to result dims [0, batchRank), and the two
/// remaining source dims to result dims batchRank and batchRank + 1.
static Value expandBatchDims(Value value,
    RankedTensorType outputType,
    size_t batchRank,
    PatternRewriter& rewriter,
    Location loc) {
  if (cast(value.getType()) == outputType)
    return value;
  SmallVector reassociation = {ReassociationIndices {},
                               ReassociationIndices {static_cast(batchRank)},
                               ReassociationIndices {static_cast(batchRank + 1)}};
  for (size_t dim = 0; dim < batchRank; ++dim)
    reassociation.front().push_back(static_cast(dim));
  auto buildExpanded = [&](Value input) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
  };
  return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
}

/// Reshapes \p value into the rank-2 \p resultType with a tensor.expand_shape
/// whose single reassociation group is {0, 1}, i.e. the one source dimension
/// is split into both result dimensions.
static Value
createMatrixFromVector(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  auto buildExpanded = [&](Value input) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc,
resultType, input, SmallVector { {0, 1} });
  };
  return materializeOrComputeUnary(value, resultType, rewriter, loc, buildExpanded);
}

/// Builds collapse_shape reassociation groups from a per-axis "remove" mask.
///
/// Every kept axis (mask entry false) closes a group; removed axes that come
/// before it are part of that group. Removed axes after the last kept axis
/// are appended to the last group. If no axis is kept at all, every axis ends
/// up in one single group.
static SmallVector buildCollapseReassociation(ArrayRef removedAxes) {
  SmallVector reassociation;
  ReassociationIndices currentGroup;
  for (auto [axis, removeAxis] : llvm::enumerate(removedAxes)) {
    currentGroup.push_back(axis);
    if (!removeAxis) {
      // A kept axis terminates the current group.
      reassociation.push_back(currentGroup);
      currentGroup.clear();
    }
  }
  // Leftover axes are trailing removed axes.
  if (!currentGroup.empty()) {
    if (reassociation.empty())
      reassociation.push_back(std::move(currentGroup));
    else
      reassociation.back().append(currentGroup.begin(), currentGroup.end());
  }
  return reassociation;
}

/// Collapses \p value to \p resultType by dropping the axes flagged in
/// \p removedAxes.
///
/// Returns \p value unchanged when it already has \p resultType. For a rank-0
/// result the reassociation list is empty; otherwise it comes from
/// buildCollapseReassociation.
static Value squeezeUnitDims(
    Value value, RankedTensorType resultType, ArrayRef removedAxes, PatternRewriter& rewriter, Location loc) {
  if (cast(value.getType()) == resultType)
    return value;
  SmallVector reassociation = resultType.getRank() == 0 ? SmallVector {} : buildCollapseReassociation(removedAxes);
  auto buildCollapsed = [&](Value input) -> Value {
    return tensor::CollapseShapeOp::create(rewriter, loc, resultType, input, reassociation).getResult();
  };
  return materializeOrComputeUnary(value, resultType, rewriter, loc, buildCollapsed);
}

/// Returns \p value as a rank-3 tensor of shape {batchSize, rows, cols}.
///
/// Rank-3 values are returned unchanged. Otherwise a tensor.expand_shape with
/// reassociation {{0, 1}, {2}} is created: source dim 0 is split into result
/// dims 0 and 1, source dim 1 becomes result dim 2.
static Value ensureBatchedTensor(
    Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
  auto type = cast(value.getType());
  if (type.getRank() == 3)
    return value;
  auto batchedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
  auto buildExpanded = [&](Value input) -> Value {
    return tensor::ExpandShapeOp::create(rewriter, loc, batchedType, input, SmallVector { {0, 1}, {2} });
  };
  return materializeOrComputeUnary(value, batchedType, rewriter, loc, buildExpanded);
}

/// Extracts the {rows, cols} matrix at a compile-time batch index from a
/// batched tensor.
///
/// Rank-2 values are returned unchanged. Otherwise a {1, rows, cols} slice is
/// taken at offset [batchIndex, 0, 0] (offset 0 is used when batchSize is 1)
/// and collapsed to {rows, cols}. The slice and matrix types are built without
/// the encoding of the source type.
static Value extractBatchMatrix(Value value,
    int64_t batchIndex,
    int64_t batchSize,
    int64_t rows,
    int64_t cols,
    PatternRewriter& rewriter,
    Location loc) {
  auto type = cast(value.getType());
  if (type.getRank() == 2)
    return value;
  auto sliceType
= RankedTensorType::get({1, rows, cols}, type.getElementType());
  // A batch of size 1 is always read at offset 0, whatever batchIndex is.
  SmallVector offsets = {
      rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
  SmallVector sizes = { rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
  SmallVector strides = getUnitStrides(rewriter, 3);
  auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
  auto buildMatrix = [&](Value input) -> Value {
    Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
    // Fold the unit batch dimension into the row dimension.
    return tensor::CollapseShapeOp::create(rewriter, loc, matrixType, slice, SmallVector { {0, 1}, {2} });
  };
  return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
}

/// Emits an ONNXTransposeOp that swaps the last two dimensions of \p value.
///
/// Rank-2 inputs use permutation {1, 0}. Every other rank takes the second
/// path, which reads shape[0..2] and uses permutation {0, 2, 1}; callers are
/// therefore expected to pass rank-2 or rank-3 tensors only. The result type
/// keeps the element type and encoding of the input.
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
  auto type = cast(value.getType());
  auto shape = type.getShape();
  auto createONNXTranspose = [&](RankedTensorType resultType, ArrayRef permutation) {
    return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation))
        .getResult();
  };
  if (type.getRank() == 2) {
    auto resultType = RankedTensorType::get({shape[1], shape[0]}, type.getElementType(), type.getEncoding());
    return createONNXTranspose(resultType, {1, 0});
  }
  auto resultType = RankedTensorType::get({shape[0], shape[2], shape[1]}, type.getElementType(), type.getEncoding());
  return createONNXTranspose(resultType, {0, 2, 1});
}

/// Pads \p value with zeros up to \p resultType using tensor.pad.
///
/// Low padding is 0 on every dimension; high padding is
/// resultDim - sourceDim per dimension. The pad region gets one index block
/// argument per source dimension and yields a zero constant of the source
/// element type. On return the rewriter's insertion point is placed right
/// after the pad op.
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
  auto sourceType = cast(value.getType());
  SmallVector lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
  SmallVector highPads;
  highPads.reserve(sourceType.getRank());
  for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
    highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
  auto padOp = tensor::PadOp::create(rewriter, loc, resultType,
value, lowPads, highPads);
  // Build the pad region by hand; the block is handed over to the op's region
  // by push_back below.
  auto* padBlock = new Block();
  for (int64_t i = 0; i < sourceType.getRank(); ++i)
    padBlock->addArgument(rewriter.getIndexType(), loc);
  padOp.getRegion().push_back(padBlock);
  rewriter.setInsertionPointToStart(padBlock);
  auto zero = getOrCreateConstant(
      rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
  tensor::YieldOp::create(rewriter, loc, zero);
  // Continue inserting after the pad op, not inside its region.
  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}

/// Returns \p input zero-padded to \p paddedInputType.
///
/// When the input already has the padded type it is returned unchanged.
/// Otherwise a compute op is created through createSpatCompute<1>; its body
/// pads the compute argument with createZeroPaddedTensor and yields it, and
/// result 0 of the compute op is returned.
static Value createPaddedBatchedInputCompute(Value input,
    RankedTensorType paddedInputType,
    PatternRewriter& rewriter,
    Location loc) {
  auto inputType = cast(input.getType());
  if (inputType == paddedInputType)
    return input;
  auto computeOp =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
        Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
      });
  return computeOp.getResult(0);
}

/// Builds, at compile time, a constant of type \p resultType that holds the
/// weight \p value broadcast over \p targetBatchShape and zero-padded.
///
/// - Returns \p value unchanged when it already has \p resultType.
/// - Returns failure() when getHostConstDenseElementsAttr yields no attribute
///   for \p value.
/// - Otherwise every target batch slot receives a copy of the source matrix
///   selected by mapStaticBroadcastedBatchIndex (always matrix 0 for a rank-2
///   source), written at the top-left corner of the slot; all remaining
///   elements stay zero.
///
/// The body reads dims 1 and 2 of \p resultType, so \p resultType is expected
/// to be rank 3; a non-rank-2 source is read as {batch, rows, cols}.
static FailureOr materializePaddedBatchedWeight(Value value,
    ArrayRef sourceBatchShape,
    ArrayRef targetBatchShape,
    RankedTensorType resultType,
    PatternRewriter& rewriter) {
  auto sourceType = cast(value.getType());
  if (sourceType == resultType)
    return value;
  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();
  const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
  const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
  const int64_t targetBatch = targetBatchShape.empty() ?
1 : getStaticShapeElementCount(targetBatchShape);
  const int64_t targetRows = resultType.getDimSize(1);
  const int64_t targetCols = resultType.getDimSize(2);
  SmallVector sourceValues(denseAttr.getValues());
  // Start from an all-zero result so that the padding area needs no writes.
  SmallVector resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
  for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
    // A rank-2 source has a single matrix that is shared by every batch.
    const int64_t sourceBatchIdx =
        sourceType.getRank() == 2 ? 0 : mapStaticBroadcastedBatchIndex(batchIdx, sourceBatchShape, targetBatchShape);
    const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
    const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
    // Copy with the source extents but address the target with the padded
    // row length (targetCols).
    for (int64_t row = 0; row < sourceRows; ++row)
      for (int64_t col = 0; col < sourceCols; ++col)
        resultValues[targetBatchBase + row * targetCols + col] = sourceValues[sourceBatchBase + row * sourceCols + col];
  }
  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}

/// Extracts one tile of the batched A operand as a tensor of type
/// \p aTileType.
///
/// The output batch index is first mapped to the source batch index
/// (broadcast aware). A {1, 1, tileCols} slice is then taken at offset
/// [sourceBatch, row, kOffset], where tileCols is dim 1 of \p aTileType, and
/// collapsed with reassociation {{0, 1}, {2}}.
static Value extractBatchedATile(Value a,
    ArrayRef sourceBatchShape,
    ArrayRef outputBatchShape,
    Value outputBatchIndex,
    Value row,
    Value kOffset,
    RankedTensorType aTileType,
    PatternRewriter& rewriter,
    Location loc) {
  auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
  SmallVector offsets {OpFoldResult(sourceBatchIndex), row, kOffset};
  SmallVector sizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
  auto slice =
      tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, getUnitStrides(rewriter, 3));
  return tensor::CollapseShapeOp::create(rewriter, loc, aTileType, slice, SmallVector { {0, 1}, {2} });
}

/// Extracts one tile of the batched B operand as a tensor of type
/// \p bTileType.
///
/// The slice has shape {1, tileRows, tileCols} (dims 0 and 1 of \p bTileType),
/// is taken at offset [sourceBatch, kOffset, hOffset] and is collapsed with
/// reassociation {{0, 1}, {2}}.
static Value extractBatchedBTile(Value b, ArrayRef
sourceBatchShape, ArrayRef outputBatchShape, Value outputBatchIndex, Value kOffset, Value hOffset, RankedTensorType bTileType, PatternRewriter& rewriter, Location loc) {
  auto bSliceType =
      RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
  // Translate the output batch index into the (possibly broadcast) source one.
  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
  SmallVector offsets {OpFoldResult(sourceBatchIndex), kOffset, hOffset};
  SmallVector sizes {rewriter.getIndexAttr(1),
                     rewriter.getIndexAttr(bTileType.getDimSize(0)),
                     rewriter.getIndexAttr(bTileType.getDimSize(1))};
  auto slice =
      tensor::ExtractSliceOp::create(rewriter, loc, bSliceType, b, offsets, sizes, getUnitStrides(rewriter, 3));
  return tensor::CollapseShapeOp::create(rewriter, loc, bTileType, slice, SmallVector { {0, 1}, {2} });
}

/// Emits IR computing the batch index of a lane:
/// lane floordiv (numOutRows * numKSlices * numOutHSlices).
static Value getBatchLaneIndex(
    Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
  return affineFloorDivConst(
      rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
}

/// Creates the batched compute op that produces one partial VMM piece per
/// lane.
///
/// The lane count is dim 0 of \p partialPiecesType. Inside the body a lane is
/// decomposed as
///   lane = ((batch * numOutHSlices + hSlice) * numKSlices + kSlice)
///          * numOutRows + row
/// and the lane multiplies a {1, crossbarSize} tile of A by a
/// {crossbarSize, crossbarSize} tile of B with SpatVMMOp, then inserts the
/// {1, crossbarSize} result at row `lane` of the batch output.
///
/// \p b is passed in the first ValueRange and read back as args.weights;
/// \p a is passed in the second and read back as args.inputs.
///
/// \returns the value obtained by dereferencing the createSpatComputeBatch
///          result, or failure() if that result is a failure.
static FailureOr createBatchedVmmBatch(Value a,
    Value b,
    RankedTensorType aType,
    ArrayRef aBatchShape,
    RankedTensorType bType,
    ArrayRef bBatchShape,
    ArrayRef outputBatchShape,
    RankedTensorType partialPiecesType,
    int64_t numOutRows,
    int64_t numKSlices,
    int64_t numOutHSlices,
    PatternRewriter& rewriter,
    Location loc) {
  const int64_t laneCount = partialPiecesType.getDimSize(0);
  auto batchOp = createSpatComputeBatch(
      rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        // The row is the fastest-varying component of the lane index.
        Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
        Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
        Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices,
numOutHSlices, rewriter, loc);
        // sliceLane enumerates the (hSlice, kSlice) pairs of one batch with
        // kSlice varying fastest.
        Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp);
        Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
        Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
        // Slice indices become element offsets in steps of crossbarSize.
        Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp);
        Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp);
        auto aTileType =
            RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType());
        auto bTileType = RankedTensorType::get(
            {static_cast(crossbarSize.getValue()), static_cast(crossbarSize.getValue())}, bType.getElementType());
        auto pieceType =
            RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType());
        Value aTile = extractBatchedATile(
            args.inputs.front(), aBatchShape, outputBatchShape, batch, row, kOffset, aTileType, rewriter, loc);
        Value bTile = extractBatchedBTile(
            args.weights.front(), bBatchShape, outputBatchShape, batch, kOffset, hOffset, bTileType, rewriter, loc);
        // Note the operand order: the B tile comes first, then the A tile.
        Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
        // Each lane writes its piece to row `lane` of the batch output.
        SmallVector pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
      });
  if (failed(batchOp))
    return failure();
  return *batchOp;
}

/// Extracts one column of a batched rank-3 matrix and returns it with type
/// \p vectorType.
///
/// A {1, len, 1} slice (len = dim 1 of \p vectorType) is taken at offset
/// [sourceBatch, 0, column], collapsed to a rank-1 tensor of length len and
/// expanded again to \p vectorType.
static Value extractDynamicBatchedBColumn(Value matrix,
    ArrayRef sourceBatchShape,
    ArrayRef outputBatchShape,
    Value outputBatchIndex,
    Value column,
    RankedTensorType vectorType,
    PatternRewriter& rewriter,
    Location loc) {
  auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
  Value sourceBatchIndex = mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex,
sourceBatchShape, outputBatchShape, rewriter, loc);
  SmallVector offsets {OpFoldResult(sourceBatchIndex), rewriter.getIndexAttr(0), column};
  SmallVector sizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
  SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  Value columnSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides);
  // Flatten the {1, len, 1} slice to rank 1 first, then expand it to the
  // requested vector type.
  auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
  Value collapsed =
      tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, SmallVector { {0, 1, 2} })
          .getResult();
  return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, SmallVector { {0, 1} })
      .getResult();
}

/// Extracts one row of a batched rank-3 matrix and returns it with type
/// \p vectorType.
///
/// A {1, 1, len} slice (len = dim 1 of \p vectorType) is taken at offset
/// [sourceBatch, row, 0] and collapsed with reassociation {{0, 1}, {2}}.
static Value extractDynamicBatchedRowVector(Value matrix,
    ArrayRef sourceBatchShape,
    ArrayRef outputBatchShape,
    Value outputBatchIndex,
    Value row,
    RankedTensorType vectorType,
    PatternRewriter& rewriter,
    Location loc) {
  auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
  Value sourceBatchIndex =
      mapOutputBatchIndexToSourceBatchIndex(outputBatchIndex, sourceBatchShape, outputBatchShape, rewriter, loc);
  SmallVector offsets {OpFoldResult(sourceBatchIndex), row, rewriter.getIndexAttr(0)};
  SmallVector sizes {
      rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
  auto rowSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
  return tensor::CollapseShapeOp::create(rewriter, loc, vectorType, rowSlice, SmallVector { {0, 1}, {2} });
}

/// Creates the batched compute op that evaluates a MatMul with one
/// SpatVVDMulOp per output element, both operands being passed as inputs.
///
/// The lane count is numBatches * numOutRows * numOutCols (dims 0..2 of
/// \p outType) and a lane is decomposed as
///   lane = (batch * numOutRows + row) * numOutCols + column.
/// Each lane combines row `row` of A with column `column` of B (both of length
/// dim 2 of \p aType) and inserts the {1, 1} result at row `lane` of the batch
/// output. \p bType is not referenced in the body.
///
/// \returns the value obtained by dereferencing the createSpatComputeBatch
///          result, or failure() if that result is a failure.
static FailureOr createBatchedVvdmulBatch(Value a,
    ArrayRef aBatchShape,
    Value b,
    ArrayRef bBatchShape,
    ArrayRef outputBatchShape,
    RankedTensorType aType,
    RankedTensorType bType,
    RankedTensorType scalarPiecesType,
    RankedTensorType outType,
    PatternRewriter& rewriter,
Location loc) {
  const int64_t numBatches = outType.getDimSize(0);
  const int64_t numOutRows = outType.getDimSize(1);
  const int64_t numOutCols = outType.getDimSize(2);
  const int64_t reductionSize = aType.getDimSize(2);
  // One lane per output element.
  const int64_t laneCount = numBatches * numOutRows * numOutCols;
  auto batchOp = createSpatComputeBatch(
      rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) {
        Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
        // lane = (batch * numOutRows + row) * numOutCols + column
        Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
        Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
        Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
        Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
        // Both operand vectors use the element type of A.
        auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
        auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
        Value aVector = extractDynamicBatchedRowVector(
            args.inputs[0], aBatchShape, outputBatchShape, batch, row, vectorType, rewriter, loc);
        Value bVector = extractDynamicBatchedBColumn(
            args.inputs[1], bBatchShape, outputBatchShape, batch, column, vectorType, rewriter, loc);
        Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
        // Each lane writes its {1, 1} result to row `lane` of the batch output.
        SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2));
      });
  if (failed(batchOp))
    return failure();
  return *batchOp;
}

/// Creates a compute op that scatters per-lane {1, 1} results into a tensor of
/// type \p outType.
///
/// The body starts from a tensor.empty of the output shape and runs a loop
/// (buildNormalizedScfFor) from 0 to the lane count, which is dim 0 of
/// \p scalarPiecesType. Iteration `lane` extracts row `lane` of the pieces and
/// inserts it at [batch, row, column], using
///   lane = (batch * numOutRows + row) * numOutCols + column.
///
/// \returns result 0 of the compute op, or failure() if creating the loop or
///          the compute op failed.
static FailureOr createBatchedDynamicOutputCompute(Value scalarPieces,
    RankedTensorType scalarPiecesType,
    RankedTensorType outType,
    PatternRewriter& rewriter,
    Location loc) {
  const int64_t laneCount = scalarPiecesType.getDimSize(0);
  const
int64_t numOutRows = outType.getDimSize(1);
  const int64_t numOutCols = outType.getDimSize(2);
  // scalarType is what is read from the pieces; outputScalarType is the same
  // element viewed as a rank-3 slice of the output.
  auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
  auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());
  auto computeOp = createSpatCompute<1>(
      rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) -> LogicalResult {
        // The output accumulator starts as an uninitialized tensor.
        Value outputInit =
            tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
        Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
        Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
        Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
        // Loop over [0, laneCount) with step 1, carrying the output tensor.
        auto loop = buildNormalizedScfFor(
            rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit},
            [&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl& yielded) {
              Value outputAcc = iterArgs.front();
              Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
              // lane = (batch * numOutRows + row) * numOutCols + column
              Value batch = affineFloorDivConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp);
              Value batchLane = affineModConst(rewriter, nestedLoc, lane, numOutRows * numOutCols, anchorOp);
              Value row = affineFloorDivConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp);
              Value column = affineModConst(rewriter, nestedLoc, batchLane, numOutCols, anchorOp);
              // Read the {1, 1} result stored at row `lane` of the pieces.
              SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)};
              SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
              Value scalar = tensor::ExtractSliceOp::create(
                  rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
              // {1, 1} -> {1, 1, 1} so that it can be inserted into the
              // rank-3 output.
              Value expanded = tensor::ExpandShapeOp::create(rewriter, nestedLoc, outputScalarType, scalar, SmallVector { {0}, {1, 2} });
              SmallVector outputOffsets {batch, row, column};
              SmallVector outputSizes = { rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
              Value next = tensor::InsertSliceOp::create(
                  rewriter, nestedLoc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
                               .getResult();
              // The updated tensor is carried to the next iteration.
              yielded.push_back(next);
              return success();
            });
        if (failed(loop))
          return failure();
        spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
        return success();
      });
  if (failed(computeOp))
    return failure();
  return computeOp->getResult(0);
}

/// Extracts the {numOutRows, crossbarSize} block of partial results that
/// belongs to one (batch, hSlice, kSlice) triple.
///
/// The starting row inside \p partialPiecesArg is
///   batch * (numOutRows * numKSlices * numOutHSlices)
///   + hSlice * (numKSlices * numOutRows)
///   + kSlice * numOutRows,
/// where \p batch and \p hSlice are SSA values and \p kSlice is a
/// compile-time integer.
static Value extractBatchedReductionPiece(Value partialPiecesArg,
    Value batch,
    Value hSlice,
    int64_t kSlice,
    RankedTensorType pieceType,
    int64_t numKSlices,
    int64_t numOutHSlices,
    int64_t numOutRows,
    PatternRewriter& rewriter,
    Location loc) {
  Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
  Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp);
  Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp);
  // kSlice is known at compile time, so its offset is a plain constant.
  Value kOffset =
      getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
  Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
  Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
  SmallVector offsets {pieceOffset, rewriter.getIndexAttr(0)};
  SmallVector sizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
  return tensor::ExtractSliceOp::create(
      rewriter, loc, pieceType, partialPiecesArg, offsets, sizes, getUnitStrides(rewriter, 2));
}

/// Sums the numKSlices partial pieces of one (batch, hSlice) pair.
///
/// All pieces are extracted first and then added pairwise with SpatVAddOp,
/// level by level, until a single value is left. The function ends with
/// front() on the piece list, so \p numKSlices must be at least 1.
static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
    Value batch,
    Value hSlice,
    RankedTensorType pieceType,
    int64_t numKSlices,
    int64_t numOutHSlices,
    int64_t numOutRows,
    PatternRewriter& rewriter,
    Location loc) {
  SmallVector activePieces;
  activePieces.reserve(numKSlices);
  for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
    activePieces.push_back(extractBatchedReductionPiece(
        partialPiecesArg, batch, hSlice, kSlice,
pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc));
  // Pairwise reduction: every pass adds neighbouring pieces and roughly halves
  // the list.
  while (activePieces.size() > 1) {
    SmallVector nextPieces;
    nextPieces.reserve((activePieces.size() + 1) / 2);
    for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2)
      nextPieces.push_back(
          spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1])
              .getResult());
    // With an odd count the last piece has no partner and is carried over.
    if (activePieces.size() % 2 != 0)
      nextPieces.push_back(activePieces.back());
    activePieces = std::move(nextPieces);
  }
  return activePieces.front();
}

/// Creates the compute op that reduces the partial VMM pieces over the K
/// slices and assembles the result tensor.
///
/// The body allocates a tensor.empty of \p paddedOutType and fills it with two
/// nested loops (batch, then output column slice). Each inner iteration
/// reduces the pieces of one (batch, hSlice) pair and inserts the
/// {1, numOutRows, crossbarSize} result at [batch, 0, hSlice * crossbarSize].
/// When \p paddedOutType differs from \p outType, the leading
/// {numBatches, rows, cols} region of \p outType is extracted before yielding.
///
/// \returns result 0 of the compute op, or failure() if creating a loop or the
///          compute op failed.
static FailureOr createBatchedReductionCompute(Value partialPieces,
    RankedTensorType partialPiecesType,
    RankedTensorType outType,
    RankedTensorType paddedOutType,
    int64_t numBatches,
    int64_t numKSlices,
    PatternRewriter& rewriter,
    Location loc) {
  auto computeOp = createSpatCompute<1>(
      rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) -> LogicalResult {
        const int64_t numOutRows = outType.getDimSize(1);
        // Number of crossbarSize-wide column slices needed to cover dim 2 of
        // the (unpadded) output.
        const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
        auto pieceType = RankedTensorType::get({numOutRows, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType());
        auto outputSliceType = RankedTensorType::get({1, numOutRows, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType());
        Value outputInit =
            tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
        Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
        Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
        Value cNumBatches = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numBatches);
        Value cNumOutHSlices =
            getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
        // Outer loop over batches, carrying the output tensor.
        auto batchLoop = buildNormalizedScfFor(
            rewriter, loc, c0, cNumBatches, c1,
ValueRange {outputInit}, [&]( OpBuilder&, Location batchLoc, Value batch, ValueRange batchIterArgs, SmallVectorImpl& batchYielded) {
              // Inner loop over the output column slices of this batch; it
              // continues from the tensor carried by the outer loop.
              auto hLoop = buildNormalizedScfFor(
                  rewriter, batchLoc, c0, cNumOutHSlices, c1, ValueRange {batchIterArgs.front()},
                  [&](OpBuilder&, Location hLoc, Value hSlice, ValueRange hIterArgs, SmallVectorImpl& hYielded) {
                    Value outputAcc = hIterArgs.front();
                    Value reduced = reduceBatchedPartialPiecesForHSlice(
                        partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, hLoc);
                    // {rows, crossbarSize} -> {1, rows, crossbarSize} to match
                    // the rank of the output.
                    Value expandedReduced = tensor::ExpandShapeOp::create(rewriter, hLoc, outputSliceType, reduced, SmallVector { {0, 1}, {2} });
                    Value hOffset = affineMulConst(
                        rewriter, hLoc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
                    SmallVector outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
                    SmallVector outputSizes {rewriter.getIndexAttr(1),
                                             rewriter.getIndexAttr(numOutRows),
                                             rewriter.getIndexAttr(crossbarSize.getValue())};
                    Value next = tensor::InsertSliceOp::create(
                        rewriter, hLoc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
                                     .getResult();
                    hYielded.push_back(next);
                    return success();
                  });
              if (failed(hLoop))
                return failure();
              batchYielded.push_back(hLoop->results.front());
              return success();
            });
        if (failed(batchLoop))
          return failure();
        Value paddedOutput = batchLoop->results.front();
        Value result = paddedOutput;
        // Cut the padding away when the padded type is not the output type.
        if (paddedOutType != outType) {
          SmallVector outputOffsets {
              rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
          SmallVector outputSizes {rewriter.getIndexAttr(numBatches),
                                   rewriter.getIndexAttr(outType.getDimSize(1)),
                                   rewriter.getIndexAttr(outType.getDimSize(2))};
          result = tensor::ExtractSliceOp::create(
              rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3));
        }
        spatial::SpatYieldOp::create(rewriter, loc, result);
        return success();
      });
  if (failed(computeOp))
    return failure();
  return
computeOp->getResult(0);
}

/// Shape facts about one ONNXMatMulOp, as produced by analyzeMatMulShape.
struct NormalizedMatMulInfo {
  // Types of A, B and Y as they appear on the op.
  RankedTensorType lhsType;
  RankedTensorType rhsType;
  RankedTensorType outType;
  // Operand types after a rank-1 A became {1, k} and a rank-1 B became {k, 1};
  // operands of higher rank keep their type.
  RankedTensorType normalizedLhsType;
  RankedTensorType normalizedRhsType;
  // All dimensions of the normalized types except the last two.
  SmallVector lhsBatchShape;
  SmallVector rhsBatchShape;
  // Broadcast of the two batch shapes (inferSupportedBatchShape).
  SmallVector outputBatchShape;
  // True when the corresponding operand was rank 1 on the op.
  bool lhsWasVector;
  bool rhsWasVector;
  // Element counts of the batch shapes above (1 for an empty shape).
  int64_t lhsBatch;
  int64_t rhsBatch;
  int64_t batch;
  // Normalized matrix extents: A is {m, k}, B is {k, n}.
  int64_t m;
  int64_t k;
  int64_t n;
};

/// Operands and extents that a lowering works on, as produced by
/// buildLoweringPlan. In the transposed form the operand roles are swapped and
/// transposedResult is set.
struct MatMulLoweringPlan {
  Value lhs;
  Value rhs;
  RankedTensorType lhsType;
  RankedTensorType rhsType;
  SmallVector lhsBatchShape;
  SmallVector rhsBatchShape;
  SmallVector outputBatchShape;
  int64_t lhsBatch;
  int64_t rhsBatch;
  int64_t batch;
  int64_t m;
  int64_t k;
  int64_t n;
  // Set by buildLoweringPlan when the plan was built in transposed form.
  bool transposedResult;
};

/// Returns the result shape expected for a MatMul: the batch shape followed by
/// {m, n}, where m is omitted when A was a vector and n is omitted when B was
/// a vector.
static SmallVector computeExpectedMatMulOutputShape(
    ArrayRef batchShape, int64_t m, int64_t n, bool lhsWasVector, bool rhsWasVector) {
  SmallVector shape(batchShape.begin(), batchShape.end());
  if (lhsWasVector && rhsWasVector)
    return shape;
  if (lhsWasVector) {
    shape.push_back(n);
    return shape;
  }
  if (rhsWasVector) {
    shape.push_back(m);
    return shape;
  }
  shape.push_back(m);
  shape.push_back(n);
  return shape;
}

/// Validates the operand and result types of \p matmulOp and collects the
/// shape facts needed by the lowering.
///
/// Returns failure() when
/// - A, B or Y does not dyn_cast to the expected tensor type or has a
///   non-static shape,
/// - A or B has rank 0,
/// - hasStaticPositiveShape rejects any of the three types,
/// - the batch shapes cannot be broadcast,
/// - the reduction extents of A and B differ, or
/// - the shape of Y differs from computeExpectedMatMulOutputShape.
static FailureOr analyzeMatMulShape(ONNXMatMulOp matmulOp) {
  auto lhsType = dyn_cast(matmulOp.getA().getType());
  auto rhsType = dyn_cast(matmulOp.getB().getType());
  auto outType = dyn_cast(matmulOp.getY().getType());
  if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape()
      || !outType.hasStaticShape())
    return failure();
  if (lhsType.getRank() < 1 || rhsType.getRank() < 1)
    return failure();
  if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
    return failure();
  const bool lhsWasVector = lhsType.getRank() == 1;
  const bool rhsWasVector = rhsType.getRank() == 1;
  // A rank-1 A is treated as a {1, k} row matrix ...
  auto normalizedLhsType = lhsWasVector ? RankedTensorType::get({1, lhsType.getDimSize(0)}, lhsType.getElementType(), lhsType.getEncoding()) : lhsType;
  // ... and a rank-1 B as a {k, 1} column matrix.
  auto normalizedRhsType = rhsWasVector ?
RankedTensorType::get({rhsType.getDimSize(0), 1}, rhsType.getElementType(), rhsType.getEncoding()) : rhsType;
  // Batch shape = every dimension except the trailing two matrix dimensions.
  SmallVector lhsBatchShape(normalizedLhsType.getShape().begin(), normalizedLhsType.getShape().end() - 2);
  SmallVector rhsBatchShape(normalizedRhsType.getShape().begin(), normalizedRhsType.getShape().end() - 2);
  auto outputBatchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
  if (failed(outputBatchShape))
    return failure();
  // An empty batch shape counts as a single batch.
  const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
  const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
  const int64_t batch = outputBatchShape->empty() ? 1 : getStaticShapeElementCount(*outputBatchShape);
  const int64_t m = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 2);
  const int64_t k = normalizedLhsType.getDimSize(normalizedLhsType.getRank() - 1);
  const int64_t rhsK = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 2);
  const int64_t n = normalizedRhsType.getDimSize(normalizedRhsType.getRank() - 1);
  // The reduction extents of both operands must agree.
  if (k != rhsK)
    return failure();
  // The declared result shape must be exactly the shape implied by the
  // operands.
  if (SmallVector(outType.getShape().begin(), outType.getShape().end())
      != computeExpectedMatMulOutputShape(*outputBatchShape, m, n, lhsWasVector, rhsWasVector)) {
    return failure();
  }
  return NormalizedMatMulInfo {lhsType,
                               rhsType,
                               outType,
                               normalizedLhsType,
                               normalizedRhsType,
                               lhsBatchShape,
                               rhsBatchShape,
                               *outputBatchShape,
                               lhsWasVector,
                               rhsWasVector,
                               lhsBatch,
                               rhsBatch,
                               batch,
                               m,
                               k,
                               n};
}

/// Builds the lowering plan for a MatMul from its normalized operands.
///
/// Without \p useTransposedForm the plan mirrors \p info and
/// transposedResult is false. With it, the plan describes B^T * A^T instead:
/// lhs is the transposed normalized rhs, rhs is the transposed normalized lhs,
/// the lhs/rhs batch shapes and batch counts are swapped, m and n are
/// exchanged and transposedResult is true. k, batch and outputBatchShape are
/// the same in both forms.
static MatMulLoweringPlan buildLoweringPlan(Value normalizedLhs,
    Value normalizedRhs,
    const NormalizedMatMulInfo& info,
    bool useTransposedForm,
    PatternRewriter& rewriter,
    Location loc) {
  MatMulLoweringPlan plan {normalizedLhs,
                           normalizedRhs,
                           cast(normalizedLhs.getType()),
                           cast(normalizedRhs.getType()),
                           info.lhsBatchShape,
                           info.rhsBatchShape,
                           info.outputBatchShape,
                           info.lhsBatch,
                           info.rhsBatch,
                           info.batch,
                           info.m,
                           info.k,
                           info.n,
                           false};
  if (!useTransposedForm)
    return plan;
  plan.lhs =
transposeLastTwoDims(normalizedRhs, rewriter, loc); plan.rhs = transposeLastTwoDims(normalizedLhs, rewriter, loc); plan.lhsType = cast(plan.lhs.getType()); plan.rhsType = cast(plan.rhs.getType()); std::swap(plan.lhsBatchShape, plan.rhsBatchShape); std::swap(plan.lhsBatch, plan.rhsBatch); plan.m = info.n; plan.n = info.m; plan.transposedResult = true; return plan; } static Value normalizeMatMulOperand( Value value, RankedTensorType normalizedType, bool wasVector, PatternRewriter& rewriter, Location loc) { if (!wasVector) return value; return createMatrixFromVector(value, normalizedType, rewriter, loc); } static Value finalizeNormalizedMatMulResult(Value value, RankedTensorType directOutType, const NormalizedMatMulInfo& info, PatternRewriter& rewriter, Location loc) { // The direct lowered result is always [flatBatch, normalizedM, normalizedN]. // Restore ONNX MatMul result rank by expanding right-aligned batch dimensions // and removing the synthetic unit matrix axes introduced for vector operands. 
Value result = value; RankedTensorType currentType = directOutType; if (info.outputBatchShape.size() > 1) { SmallVector expandedShape(info.outputBatchShape.begin(), info.outputBatchShape.end()); expandedShape.push_back(info.m); expandedShape.push_back(info.n); auto expandedType = RankedTensorType::get(expandedShape, info.outType.getElementType(), info.outType.getEncoding()); result = expandBatchDims(result, expandedType, info.outputBatchShape.size(), rewriter, loc); currentType = expandedType; } SmallVector removedAxes(currentType.getRank(), false); if (info.outputBatchShape.empty()) removedAxes[0] = true; if (info.lhsWasVector) removedAxes[currentType.getRank() - 2] = true; if (info.rhsWasVector) removedAxes[currentType.getRank() - 1] = true; return squeezeUnitDims(result, info.outType, removedAxes, rewriter, loc); } struct MatMulToGemm : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { auto shapeInfo = analyzeMatMulShape(matmulOp); if (failed(shapeInfo) || shapeInfo->lhsWasVector || shapeInfo->rhsWasVector || !shapeInfo->outputBatchShape.empty()) return failure(); Location loc = matmulOp.getLoc(); bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc); Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc); int64_t lhsBatchForGemm = shapeInfo->lhsBatch; int64_t rhsBatchForGemm = shapeInfo->rhsBatch; int64_t gemmM = shapeInfo->m; int64_t gemmK = shapeInfo->k; int64_t gemmN = shapeInfo->n; if (useTransposedForm) { lhs = transposeLastTwoDims(matmulOp.getB(), rewriter, loc); lhsBatchForGemm = shapeInfo->rhsBatch; rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc); rhsBatchForGemm = shapeInfo->lhsBatch; gemmM = shapeInfo->n; gemmN = 
shapeInfo->m; } auto gemmType = RankedTensorType::get({gemmM, gemmN}, shapeInfo->outType.getElementType()); Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); Value gemmResult = ONNXGemmOp::create(rewriter, loc, gemmType, lhsMatrix, rhsMatrix, none, rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f), rewriter.getBoolAttr(false), rewriter.getBoolAttr(false)) .getY(); if (useTransposedForm) gemmResult = ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, gemmResult, rewriter.getI64ArrayAttr({1, 0})) .getResult(); rewriter.replaceOp(matmulOp, gemmResult); return success(); } }; struct MatMulBatchedToSpatialComputes : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { auto shapeInfo = analyzeMatMulShape(matmulOp); if (failed(shapeInfo)) return failure(); if (!shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector && shapeInfo->outputBatchShape.empty()) return failure(); Location loc = matmulOp.getLoc(); bool useTransposedForm = !shapeInfo->lhsWasVector && !shapeInfo->rhsWasVector && isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); Value lhs = normalizeMatMulOperand(matmulOp.getA(), shapeInfo->normalizedLhsType, shapeInfo->lhsWasVector, rewriter, loc); Value rhs = normalizeMatMulOperand(matmulOp.getB(), shapeInfo->normalizedRhsType, shapeInfo->rhsWasVector, rewriter, loc); lhs = collapseBatchDims(lhs, shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc); rhs = collapseBatchDims(rhs, shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc); MatMulLoweringPlan plan = buildLoweringPlan(lhs, rhs, *shapeInfo, useTransposedForm, rewriter, loc); plan.lhs = 
ensureBatchedTensor(plan.lhs, plan.lhsBatch, plan.m, plan.k, rewriter, loc); plan.rhs = ensureBatchedTensor(plan.rhs, plan.rhsBatch, plan.k, plan.n, rewriter, loc); plan.lhsType = cast(plan.lhs.getType()); plan.rhsType = cast(plan.rhs.getType()); auto directOutType = RankedTensorType::get( {plan.batch, plan.m, plan.n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); if (isCompileTimeComputable(plan.rhs)) { const int64_t numKSlices = ceilIntegerDivide(plan.k, crossbarSize.getValue()); const int64_t numOutHSlices = ceilIntegerDivide(plan.n, crossbarSize.getValue()); const int64_t paddedReductionSize = numKSlices * static_cast(crossbarSize.getValue()); const int64_t paddedOutCols = numOutHSlices * static_cast(crossbarSize.getValue()); auto paddedLhsType = RankedTensorType::get( {plan.lhsBatch, plan.m, paddedReductionSize}, plan.lhsType.getElementType(), plan.lhsType.getEncoding()); auto paddedRhsType = RankedTensorType::get( {plan.batch, paddedReductionSize, paddedOutCols}, plan.rhsType.getElementType(), plan.rhsType.getEncoding()); auto paddedOutType = RankedTensorType::get({plan.batch, plan.m, paddedOutCols}, shapeInfo->outType.getElementType()); auto paddedRhs = materializePaddedBatchedWeight(plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, paddedRhsType, rewriter); if (succeeded(paddedRhs)) { Value paddedLhs = createPaddedBatchedInputCompute(plan.lhs, paddedLhsType, rewriter, loc); const int64_t laneCount = plan.batch * plan.m * numKSlices * numOutHSlices; auto partialPiecesType = RankedTensorType::get({laneCount, static_cast(crossbarSize.getValue())}, shapeInfo->outType.getElementType()); auto batchOp = createBatchedVmmBatch(paddedLhs, *paddedRhs, paddedLhsType, plan.lhsBatchShape, paddedRhsType, plan.rhsBatchShape, plan.outputBatchShape, partialPiecesType, plan.m, numKSlices, numOutHSlices, rewriter, loc); if (failed(batchOp)) return failure(); auto result = createBatchedReductionCompute(batchOp->getResult(0), partialPiecesType, 
directOutType, paddedOutType, plan.batch, numKSlices, rewriter, loc); if (failed(result)) return failure(); Value finalResult = *result; if (plan.transposedResult) { auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); finalResult = ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1})) .getResult(); } finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc); rewriter.replaceOp(matmulOp, finalResult); return success(); } } const int64_t laneCount = plan.batch * plan.m * plan.n; auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType()); auto batchOp = createBatchedVvdmulBatch(plan.lhs, plan.lhsBatchShape, plan.rhs, plan.rhsBatchShape, plan.outputBatchShape, plan.lhsType, plan.rhsType, scalarPiecesType, directOutType, rewriter, loc); if (failed(batchOp)) return failure(); auto result = createBatchedDynamicOutputCompute(batchOp->getResult(0), scalarPiecesType, directOutType, rewriter, loc); if (failed(result)) return failure(); Value finalResult = *result; if (plan.transposedResult) { auto transposedOutType = RankedTensorType::get({plan.batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType(), shapeInfo->outType.getEncoding()); finalResult = ONNXTransposeOp::create(rewriter, loc, transposedOutType, finalResult, rewriter.getI64ArrayAttr({0, 2, 1})) .getResult(); } finalResult = finalizeNormalizedMatMulResult(finalResult, directOutType, *shapeInfo, rewriter, loc); rewriter.replaceOp(matmulOp, finalResult); return success(); } }; } // namespace void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.insert(ctx); } } // namespace onnx_mlir