diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index ca04343..3fa2a59 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -51,23 +51,107 @@ static Value createPaddedRows(Value tensorValue, if (tensorType.getDimSize(0) == paddedRows) return tensorValue; - auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType()); + auto paddedType = + RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding()); SmallVector lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)), rewriter.getIndexAttr(0)}; auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads); auto* padBlock = new Block(); - for (int i = 0; i < 2; i++) + for (int i = 0; i < 2; ++i) padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); - auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), + auto zero = getOrCreateConstant(rewriter, + padOp.getOperation(), + rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType()); tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(padOp); return padOp.getResult(); } +static Value packRowsForParallelGemm(Value rows, + RankedTensorType rowsType, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (packFactor == 1) + return rows; + + const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor); + const int64_t paddedNumRows = packedNumRows * packFactor; + const int64_t rowWidth = rowsType.getDimSize(1); + auto groupedType = + RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding()); + auto packedType = + RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding()); + + Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc); + Value groupedRows = tensor::ExpandShapeOp::create(rewriter, + loc, + groupedType, + paddedRows, + SmallVector { + {0, 1}, + {2} + }); + return tensor::CollapseShapeOp::create(rewriter, + loc, + packedType, + groupedRows, + SmallVector { + {0}, + {1, 2} + }); +} + +static Value unpackRowsFromParallelGemm(Value packedRows, + RankedTensorType packedRowsType, + int64_t unpackedRows, + int64_t rowWidth, + int64_t packFactor, + ConversionPatternRewriter& rewriter, + Location loc) { + if (packFactor == 1) + return packedRows; + + const int64_t packedNumRows = packedRowsType.getDimSize(0); + const int64_t paddedNumRows = packedNumRows * packFactor; + auto expandedType = + RankedTensorType::get({packedNumRows, packFactor, rowWidth}, + packedRowsType.getElementType(), + packedRowsType.getEncoding()); + auto paddedType = + RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding()); + auto unpackedType = + RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding()); + + Value expandedRows = tensor::ExpandShapeOp::create(rewriter, + loc, + expandedType, + packedRows, + SmallVector { + {0}, + {1, 2} + }); + Value paddedRows = tensor::CollapseShapeOp::create(rewriter, + loc, + paddedType, + expandedRows, + SmallVector { + {0, 1}, + {2} + }); + if (paddedNumRows == unpackedRows) + return paddedRows; + + SmallVector offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)}; + SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides); +} + static Value buildPackedWeight(DenseElementsAttr wDenseAttr, Value wTrans, RankedTensorType wType, @@ -189,7 +273,6 @@ static Value createIm2colRowComputes(Value x, Location loc) { auto elemType = xType.getElementType(); constexpr size_t numInputs = 1; - const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); auto im2colComputeOp = createSpatCompute(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) { Value paddedInput = xArg; @@ -278,26 +361,7 @@ static Value createIm2colRowComputes(Value x, Value gemmInputRows = im2col; if (packFactor != 1) { - const int64_t paddedNumPatches = packedNumRows * packFactor; - auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType); - auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType); - Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc); - Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter, - loc, - groupedType, - paddedIm2col, - SmallVector { - {0, 1}, - {2} - }); - gemmInputRows = tensor::CollapseShapeOp::create(rewriter, - loc, - packedType, - groupedIm2col, - SmallVector { - {0}, - {1, 2} - }); + gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc); } spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows); @@ -316,41 +380,15 @@ static Value createCollectedConvOutput(ValueRange gemmRows, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) { - const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor); - const int64_t paddedNumPatches = packedNumRows * packFactor; auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) { Value gemmOut; if (packFactor == 1) { gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); } else { - auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType()); - auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType()); Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs); - Value expandedOutput = tensor::ExpandShapeOp::create(rewriter, - loc, - expandedType, - packedOutput, - SmallVector { - {0}, - {1, 2} - }); - Value paddedOutput = tensor::CollapseShapeOp::create(rewriter, - loc, - paddedType, - expandedOutput, - SmallVector { - {0, 1}, - {2} - }); - - gemmOut = paddedOutput; - if (paddedNumPatches != numPatches) { - SmallVector offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides); - } + gemmOut = unpackRowsFromParallelGemm( + packedOutput, cast(packedOutput.getType()), numPatches, numChannelsOut, packFactor, rewriter, loc); } // Restore to NCHW layout: diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index eb7329e..7037125 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -1,3 +1,5 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" @@ -5,9 +7,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include -#include - #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" @@ -66,6 +65,26 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded); } +static Value ensureBatchedTensor( + Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) { + auto type = cast(value.getType()); + if (type.getRank() == 3) + return value; + + auto batchedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding()); + auto buildExpanded = [&](Value input) -> Value { + return tensor::ExpandShapeOp::create(rewriter, + loc, + batchedType, + input, + SmallVector { + {0, 1}, + {2} + }); + }; + return materializeOrComputeUnary(value, batchedType, rewriter, loc, buildExpanded); +} + static Value extractBatchMatrix(Value value, int64_t batchIndex, int64_t batchSize, @@ -130,164 +149,737 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite perm = {0, 2, 1}; } + auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { + Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); + spatial::SpatYieldOp::create(rewriter, loc, transposed); + }); + return transposeCompute.getResult(0); +} + +static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) { + auto sourceType = cast(value.getType()); + SmallVector lowPads(sourceType.getRank(), rewriter.getIndexAttr(0)); + SmallVector highPads; + highPads.reserve(sourceType.getRank()); + for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape())) + highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim)); + + auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); + auto* padBlock = new Block(); + for (int64_t i = 0; i < sourceType.getRank(); ++i) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = getOrCreateConstant( + rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType()); + tensor::YieldOp::create(rewriter, loc, zero); + rewriter.setInsertionPointAfter(padOp); + return padOp.getResult(); +} + +static Value createPaddedBatchedInputCompute(Value input, + RankedTensorType paddedInputType, + PatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + if (inputType == paddedInputType) + return input; + + auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) { + Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc); + spatial::SpatYieldOp::create(rewriter, loc, paddedInput); + }); + return computeOp.getResult(0); +} + +static FailureOr materializePaddedBatchedWeight( + Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) { + auto sourceType = cast(value.getType()); + if (sourceType == resultType) + return value; + + auto denseAttr = getHostConstDenseElementsAttr(value); + if (!denseAttr) + return failure(); + + const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1); + const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2); + const int64_t targetRows = resultType.getDimSize(1); + const int64_t targetCols = resultType.getDimSize(2); + SmallVector sourceValues(denseAttr.getValues()); + SmallVector resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType())); + + for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) { + const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx); + const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols; + const int64_t targetBatchBase = batchIdx * targetRows * targetCols; + for (int64_t row = 0; row < sourceRows; ++row) + for (int64_t col = 0; col < sourceCols; ++col) + resultValues[targetBatchBase + row * targetCols + col] = sourceValues[sourceBatchBase + row * sourceCols + col]; + } + + auto resultAttr = DenseElementsAttr::get(resultType, resultValues); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType); +} + +static Value extractBatchedATile(Value a, + int64_t sourceBatchCount, + Value batch, + Value row, + Value kOffset, + RankedTensorType aTileType, + PatternRewriter& rewriter, + Location loc) { + auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType()); + SmallVector offsets { + sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset}; + SmallVector sizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))}; + auto slice = + tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, getUnitStrides(rewriter, 3)); + return tensor::CollapseShapeOp::create(rewriter, + loc, + aTileType, + slice, + SmallVector { + {0, 1}, + {2} + }); +} + +static Value extractBatchedBTile(Value b, + int64_t sourceBatchCount, + Value batch, + Value kOffset, + Value hOffset, + RankedTensorType bTileType, + PatternRewriter& rewriter, + Location loc) { + auto bSliceType = + RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType()); + SmallVector offsets { + sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset}; + SmallVector sizes {rewriter.getIndexAttr(1), + rewriter.getIndexAttr(bTileType.getDimSize(0)), + rewriter.getIndexAttr(bTileType.getDimSize(1))}; + auto slice = + tensor::ExtractSliceOp::create(rewriter, loc, bSliceType, b, offsets, sizes, getUnitStrides(rewriter, 3)); + return tensor::CollapseShapeOp::create(rewriter, + loc, + bTileType, + slice, + SmallVector { + {0, 1}, + {2} + }); +} + +static Value getBatchLaneIndex( + Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) { + return floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices); +} + +static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, + Value b, + RankedTensorType aType, + int64_t aBatchCount, + RankedTensorType bType, + int64_t bBatchCount, + RankedTensorType partialPiecesType, + int64_t numOutRows, + int64_t numKSlices, + int64_t numOutHSlices, + PatternRewriter& rewriter, + Location loc) { + const int64_t laneCount = partialPiecesType.getDimSize(0); + auto batchOp = createSpatComputeBatch( + rewriter, + loc, + TypeRange {partialPiecesType}, + laneCount, + ValueRange {b}, + ValueRange {a}, + [&](detail::SpatComputeBatchBodyArgs args) { + Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows); + Value outerLane = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows); + Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); + Value sliceLane = modIndexByConstant(rewriter, loc, outerLane, numKSlices * numOutHSlices); + Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices); + Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices); + Value kOffset = + multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue()); + Value hOffset = + multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); + + auto aTileType = + RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType()); + auto bTileType = RankedTensorType::get( + {static_cast(crossbarSize.getValue()), static_cast(crossbarSize.getValue())}, + bType.getElementType()); + auto pieceType = + RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType()); + + Value aTile = + extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc); + Value bTile = + extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc); + Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); + + SmallVector pieceOffsets {args.lane, rewriter.getIndexAttr(0)}; + SmallVector pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2)); + }); + assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed"); + return *batchOp; +} + +static Value extractDynamicBatchedBColumn(Value matrix, + int64_t sourceBatchCount, + Value batch, + Value column, + RankedTensorType vectorType, + PatternRewriter& rewriter, + Location loc) { + auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType()); + SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) + : OpFoldResult(batch), + rewriter.getIndexAttr(0), + column}; + SmallVector sizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)}; + SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value columnSlice = tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides); + auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType()); + Value collapsed = tensor::CollapseShapeOp::create(rewriter, + loc, + collapsedType, + columnSlice, + SmallVector { + {0, 1, 2} + }) + .getResult(); + return tensor::ExpandShapeOp::create(rewriter, + loc, + vectorType, + collapsed, + SmallVector { + {0, 1} + }) + .getResult(); +} + +static Value extractDynamicBatchedBRow(Value matrix, + int64_t sourceBatchCount, + Value batch, + Value row, + RankedTensorType vectorType, + PatternRewriter& rewriter, + Location loc) { + auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType()); + SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) + : OpFoldResult(batch), + row, + rewriter.getIndexAttr(0)}; + SmallVector sizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; + auto rowSlice = + tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3)); + return tensor::CollapseShapeOp::create(rewriter, + loc, + vectorType, + rowSlice, + SmallVector { + {0, 1}, + {2} + }); +} + +static Value extractDynamicBatchedRowVector(Value matrix, + int64_t sourceBatchCount, + Value batch, + Value row, + RankedTensorType vectorType, + PatternRewriter& rewriter, + Location loc) { + auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType()); + SmallVector offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) + : OpFoldResult(batch), + row, + rewriter.getIndexAttr(0)}; + SmallVector sizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))}; + auto rowSlice = + tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3)); + return tensor::CollapseShapeOp::create(rewriter, + loc, + vectorType, + rowSlice, + SmallVector { + {0, 1}, + {2} + }); +} + +static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a, + int64_t aBatchCount, + Value b, + int64_t bBatchCount, + RankedTensorType aType, + RankedTensorType bType, + RankedTensorType scalarPiecesType, + RankedTensorType outType, + bool bAlreadyTransposed, + PatternRewriter& rewriter, + Location loc) { + const int64_t numBatches = outType.getDimSize(0); + const int64_t numOutRows = outType.getDimSize(1); + const int64_t numOutCols = outType.getDimSize(2); + const int64_t reductionSize = aType.getDimSize(2); + const int64_t laneCount = numBatches * numOutRows * numOutCols; + auto batchOp = createSpatComputeBatch( + rewriter, + loc, + TypeRange {scalarPiecesType}, + laneCount, + ValueRange {}, + ValueRange {a, b}, + [&](detail::SpatComputeBatchBodyArgs args) { + Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols); + Value batchLane = modIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols); + Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols); + Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols); + + auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); + auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); + Value aVector = + extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc); + Value bVector = + bAlreadyTransposed + ? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc) + : extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc); + Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); + SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2)); + }); + assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed"); + return *batchOp; +} + +static Value createBatchedDynamicOutputCompute(Value scalarPieces, + RankedTensorType scalarPiecesType, + RankedTensorType outType, + PatternRewriter& rewriter, + Location loc) { + const int64_t laneCount = scalarPiecesType.getDimSize(0); + const int64_t numOutRows = outType.getDimSize(1); + const int64_t numOutCols = outType.getDimSize(2); + auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); + auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType()); + + auto computeOp = + createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) { + Value outputInit = + tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); + Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); + auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); + rewriter.setInsertionPointToStart(loop.getBody()); + + Value lane = loop.getInductionVar(); + Value outputAcc = loop.getRegionIterArgs().front(); + Value batch = floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols); + Value batchLane = modIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols); + Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols); + Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols); + SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value scalar = tensor::ExtractSliceOp::create( + rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2)); + Value expanded = tensor::ExpandShapeOp::create(rewriter, + loc, + outputScalarType, + scalar, + SmallVector { + {0}, + {1, 2} + }); + SmallVector outputOffsets {batch, row, column}; + SmallVector outputSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + scf::YieldOp::create( + rewriter, + loc, + tensor::InsertSliceOp::create( + rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) + .getResult()); + + rewriter.setInsertionPointAfter(loop); + spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0)); + }); + return computeOp.getResult(0); +} + +static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) { auto transposeCompute = - createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { - Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); + createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) { + Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1})); spatial::SpatYieldOp::create(rewriter, loc, transposed); }); return transposeCompute.getResult(0); } -static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) { - auto firstType = cast(inputs.front().getType()); - SmallVector outputShape(firstType.getShape().begin(), firstType.getShape().end()); - int64_t concatDimSize = 0; - for (Value input : inputs) - concatDimSize += cast(input.getType()).getDimSize(axis); - outputShape[axis] = concatDimSize; - auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); +static Value extractBatchedReductionPiece(Value partialPiecesArg, + Value batch, + Value hSlice, + int64_t kSlice, + RankedTensorType pieceType, + int64_t numKSlices, + int64_t numOutHSlices, + int64_t numOutRows, + PatternRewriter& rewriter, + Location loc) { + Value batchOffset = multiplyIndexByConstant( + rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, numOutRows * numKSlices * numOutHSlices); + Value hOffset = + multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, numKSlices * numOutRows); + Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows); + Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset); + Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset); + SmallVector offsets {pieceOffset, rewriter.getIndexAttr(0)}; + SmallVector sizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())}; + return tensor::ExtractSliceOp::create( + rewriter, loc, pieceType, partialPiecesArg, offsets, sizes, getUnitStrides(rewriter, 2)); +} - if (llvm::all_of(inputs, isCompileTimeComputable)) - return createSpatConcat(rewriter, loc, axis, inputs); +static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg, + Value batch, + Value hSlice, + RankedTensorType pieceType, + int64_t numKSlices, + int64_t numOutHSlices, + int64_t numOutRows, + PatternRewriter& rewriter, + Location loc) { + SmallVector activePieces; + activePieces.reserve(numKSlices); + for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice) + activePieces.push_back(extractBatchedReductionPiece( + partialPiecesArg, batch, hSlice, kSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc)); - auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args)); - }); - return concatCompute.getResult(0); + while (activePieces.size() > 1) { + SmallVector nextPieces; + nextPieces.reserve((activePieces.size() + 1) / 2); + for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2) + nextPieces.push_back( + spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1]) + .getResult()); + if (activePieces.size() % 2 != 0) + nextPieces.push_back(activePieces.back()); + activePieces = std::move(nextPieces); + } + + return activePieces.front(); +} + +static Value createBatchedReductionCompute(Value partialPieces, + RankedTensorType partialPiecesType, + RankedTensorType outType, + RankedTensorType paddedOutType, + int64_t numBatches, + int64_t numKSlices, + PatternRewriter& rewriter, + Location loc) { + auto computeOp = createSpatCompute<1>( + rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) { + const int64_t numOutRows = outType.getDimSize(1); + const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue()); + auto pieceType = RankedTensorType::get({numOutRows, static_cast(crossbarSize.getValue())}, + partialPiecesType.getElementType()); + auto outputSliceType = RankedTensorType::get({1, numOutRows, static_cast(crossbarSize.getValue())}, + partialPiecesType.getElementType()); + + Value outputInit = + tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult(); + Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cNumBatches = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numBatches); + Value cNumOutHSlices = + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); + + auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit}); + rewriter.setInsertionPointToStart(batchLoop.getBody()); + Value batch = batchLoop.getInductionVar(); + Value batchAcc = batchLoop.getRegionIterArgs().front(); + + auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc}); + rewriter.setInsertionPointToStart(hLoop.getBody()); + Value hSlice = hLoop.getInductionVar(); + Value outputAcc = hLoop.getRegionIterArgs().front(); + + Value reduced = reduceBatchedPartialPiecesForHSlice( + partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc); + Value expandedReduced = tensor::ExpandShapeOp::create(rewriter, + loc, + outputSliceType, + reduced, + SmallVector { + {0, 1}, + {2} + }); + Value hOffset = + multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); + SmallVector outputOffsets {batch, rewriter.getIndexAttr(0), hOffset}; + SmallVector outputSizes { + rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())}; + scf::YieldOp::create( + rewriter, + loc, + tensor::InsertSliceOp::create( + rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)) + .getResult()); + + rewriter.setInsertionPointAfter(hLoop); + scf::YieldOp::create(rewriter, loc, hLoop.getResult(0)); + + rewriter.setInsertionPointAfter(batchLoop); + Value paddedOutput = batchLoop.getResult(0); + Value result = paddedOutput; + if (paddedOutType != outType) { + SmallVector outputOffsets { + rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(numBatches), + rewriter.getIndexAttr(outType.getDimSize(1)), + rewriter.getIndexAttr(outType.getDimSize(2))}; + result = tensor::ExtractSliceOp::create( + rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3)); + } + spatial::SpatYieldOp::create(rewriter, loc, result); + }); + return computeOp.getResult(0); +} + +struct MatMulShapeInfo { + RankedTensorType lhsType; + RankedTensorType rhsType; + RankedTensorType outType; + SmallVector batchShape; + int64_t lhsBatch; + int64_t rhsBatch; + int64_t batch; + int64_t m; + int64_t k; + int64_t n; +}; + +static FailureOr analyzeMatMulShape(ONNXMatMulOp matmulOp) { + auto lhsType = dyn_cast(matmulOp.getA().getType()); + auto rhsType = dyn_cast(matmulOp.getB().getType()); + auto outType = dyn_cast(matmulOp.getY().getType()); + if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() + || !outType.hasStaticShape()) + return failure(); + if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2) + return failure(); + if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType)) + return failure(); + + SmallVector lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2); + SmallVector rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2); + auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape); + if (failed(batchShape)) + return failure(); + + const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape); + const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape); + const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape); + const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2); + const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1); + const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2); + const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1); + if (k != rhsK) + return failure(); + + if (outType.getRank() == 2) { + if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n) + return failure(); + } + else { + SmallVector outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2); + if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m + || outType.getDimSize(outType.getRank() - 1) != n) + return failure(); + } + + return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n}; } struct MatMulToGemm : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { - auto lhsType = dyn_cast(matmulOp.getA().getType()); - auto rhsType = dyn_cast(matmulOp.getB().getType()); - auto outType = dyn_cast(matmulOp.getY().getType()); - if (!lhsType || !rhsType || !outType || !lhsType.hasStaticShape() || !rhsType.hasStaticShape() - || !outType.hasStaticShape()) + auto shapeInfo = analyzeMatMulShape(matmulOp); + if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2) return failure(); - if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2) - return failure(); - if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType)) - return failure(); - - SmallVector lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2); - SmallVector rhsBatchShape(rhsType.getShape().begin(), rhsType.getShape().end() - 2); - auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape); - if (failed(batchShape)) - return failure(); - const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape); - const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape); - const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape); - - const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2); - const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1); - const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2); - const int64_t n = rhsType.getDimSize(rhsType.getRank() - 1); - if (k != rhsK) - return failure(); - - if (outType.getRank() == 2) { - if (batch != 1 || outType.getDimSize(0) != m || outType.getDimSize(1) != n) - return failure(); - } - else { - SmallVector outBatchShape(outType.getShape().begin(), outType.getShape().end() - 2); - if (!llvm::equal(outBatchShape, *batchShape) || outType.getDimSize(outType.getRank() - 2) != m - || outType.getDimSize(outType.getRank() - 1) != n) - return failure(); - } Location loc = matmulOp.getLoc(); bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); - Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc); - Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc); - int64_t lhsBatchForGemm = lhsBatch; - int64_t rhsBatchForGemm = rhsBatch; - int64_t gemmM = m; - int64_t gemmK = k; - int64_t gemmN = n; + Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc); + Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc); + int64_t lhsBatchForGemm = shapeInfo->lhsBatch; + int64_t rhsBatchForGemm = shapeInfo->rhsBatch; + int64_t gemmM = shapeInfo->m; + int64_t gemmK = shapeInfo->k; + int64_t gemmN = shapeInfo->n; if (useTransposedForm) { lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc); - lhsBatchForGemm = rhsBatch; + lhsBatchForGemm = shapeInfo->rhsBatch; rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc); - rhsBatchForGemm = lhsBatch; - gemmM = n; - gemmN = m; + rhsBatchForGemm = shapeInfo->lhsBatch; + gemmM = shapeInfo->n; + gemmN = shapeInfo->m; } - auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType()); - auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType()); + auto gemmType = RankedTensorType::get({gemmM, gemmN}, shapeInfo->outType.getElementType()); Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType()); - - if (outType.getRank() == 2) { - Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); - Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); - Value gemmResult = ONNXGemmOp::create(rewriter, - loc, - gemmType, - lhsMatrix, - rhsMatrix, - none, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - if (useTransposedForm) { - auto transposeCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) { - Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0})); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); - gemmResult = transposeCompute.getResult(0); - } - rewriter.replaceOp(matmulOp, gemmResult); - return success(); - } - - SmallVector batchResults; - batchResults.reserve(batch); - for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) { - Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); - Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); - Value gemmResult = ONNXGemmOp::create(rewriter, - loc, - gemmType, - lhsMatrix, - rhsMatrix, - none, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - rewriter.getBoolAttr(false), - rewriter.getBoolAttr(false)) - .getY(); - auto batchResultCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) { - Value resultMatrix = input; - if (useTransposedForm) { - resultMatrix = ONNXTransposeOp::create(rewriter, - loc, - RankedTensorType::get({m, n}, outType.getElementType()), - input, - rewriter.getI64ArrayAttr({1, 0})); - } - Value expanded = tensor::ExpandShapeOp::create(rewriter, - loc, - batchedOutType, - resultMatrix, - SmallVector { - {0, 1}, - {2} - }); - spatial::SpatYieldOp::create(rewriter, loc, expanded); + Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); + Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); + Value gemmResult = ONNXGemmOp::create(rewriter, + loc, + gemmType, + lhsMatrix, + rhsMatrix, + none, + rewriter.getF32FloatAttr(1.0f), + rewriter.getF32FloatAttr(1.0f), + rewriter.getBoolAttr(false), + rewriter.getBoolAttr(false)) + .getY(); + if (useTransposedForm) { + auto transposeCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) { + Value transposed = + ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0})); + spatial::SpatYieldOp::create(rewriter, loc, transposed); }); - batchResults.push_back(batchResultCompute.getResult(0)); + gemmResult = transposeCompute.getResult(0); + } + rewriter.replaceOp(matmulOp, gemmResult); + return success(); + } +}; + +struct MatMulBatchedToSpatialComputes : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override { + auto shapeInfo = analyzeMatMulShape(matmulOp); + if (failed(shapeInfo)) + return failure(); + if (shapeInfo->outType.getRank() == 2) + return failure(); + + Location loc = matmulOp.getLoc(); + bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB()); + + Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc); + Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc); + int64_t lhsBatchForGemm = shapeInfo->lhsBatch; + int64_t rhsBatchForGemm = shapeInfo->rhsBatch; + int64_t gemmM = shapeInfo->m; + int64_t gemmK = shapeInfo->k; + int64_t gemmN = shapeInfo->n; + if (useTransposedForm) { + lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc); + lhsBatchForGemm = shapeInfo->rhsBatch; + rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc); + rhsBatchForGemm = shapeInfo->lhsBatch; + gemmM = shapeInfo->n; + gemmN = shapeInfo->m; } - Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc); - result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc); + lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc); + rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc); + auto lhsBatchedType = cast(lhs.getType()); + auto rhsBatchedType = cast(rhs.getType()); + auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType()); + + if (isCompileTimeComputable(rhs)) { + const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue()); + const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue()); + const int64_t paddedReductionSize = numKSlices * static_cast(crossbarSize.getValue()); + const int64_t paddedOutCols = numOutHSlices * static_cast(crossbarSize.getValue()); + auto paddedLhsType = RankedTensorType::get( + {lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding()); + auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols}, + rhsBatchedType.getElementType(), + rhsBatchedType.getEncoding()); + auto paddedOutType = + RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType()); + + auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter); + if (succeeded(paddedRhs)) { + Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc); + const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices; + auto partialPiecesType = RankedTensorType::get({laneCount, static_cast(crossbarSize.getValue())}, + shapeInfo->outType.getElementType()); + auto batchOp = createBatchedVmmBatch(paddedLhs, + *paddedRhs, + paddedLhsType, + lhsBatchForGemm, + paddedRhsType, + rhsBatchForGemm, + partialPiecesType, + gemmM, + numKSlices, + numOutHSlices, + rewriter, + loc); + Value result = createBatchedReductionCompute(batchOp.getResult(0), + partialPiecesType, + directOutType, + paddedOutType, + shapeInfo->batch, + numKSlices, + rewriter, + loc); + if (useTransposedForm) + result = transposeBatchedOutput( + result, + RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), + rewriter, + loc); + result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); + rewriter.replaceOp(matmulOp, result); + return success(); + } + } + const int64_t laneCount = shapeInfo->batch * gemmM * gemmN; + auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType()); + auto batchOp = createBatchedVvdmulBatch(lhs, + lhsBatchForGemm, + rhs, + rhsBatchForGemm, + lhsBatchedType, + rhsBatchedType, + scalarPiecesType, + directOutType, + false, + rewriter, + loc); + Value result = + createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc); + if (useTransposedForm) + result = transposeBatchedOutput( + result, + RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()), + rewriter, + loc); + result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc); rewriter.replaceOp(matmulOp, result); return success(); } @@ -296,7 +888,7 @@ struct MatMulToGemm : OpRewritePattern { } // namespace void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx); + patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/validation/operations/README.md b/validation/operations/README.md index 237d1c8..e724410 100644 --- a/validation/operations/README.md +++ b/validation/operations/README.md @@ -48,12 +48,16 @@ python3 validation/operations/gen_tests.py ## MatMul -| Test | Directory | A input | B tensor | Output | Notes | -|------------|---------------------|---------|----------|---------|------------------------------------| -| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path | -| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path | -| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands | -| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch MatMul rewrite path | +| Test | Directory | A input | B tensor | Output | Notes | +|---------------------|----------------------------------|----------|----------|---------|-------------------------------------------------| +| Basic | `matmul/basic` | [2,3] | [3,4] | [2,4] | Direct 2D MatMul rewrite path | +| Left constant | `matmul/left_constant` | [2,3] | [3,4] | [2,4] | Constant LHS transpose rewrite path | +| Dynamic | `matmul/dynamic` | [2,3] | [3,4] | [2,4] | Runtime matrix operands | +| Batched 3D | `matmul/batched_3d` | [2,2,3] | [2,3,4] | [2,2,4] | Matching-batch direct batched lowering | +| Batched 3D dynamic | `matmul/batched_3d_dynamic` | [2,2,3] | [2,3,4] | [2,2,4] | Batched runtime operands | +| Batched left const | `matmul/batched_left_constant` | [2,2,3] | [2,3,4] | [2,2,4] | Batched constant-LHS transpose path | +| Batched RHS broadcast | `matmul/batched_rhs_broadcast` | [2,2,3] | [3,4] | [2,2,4] | Rank-2 RHS broadcast across batch | +| Batched LHS broadcast | `matmul/batched_lhs_broadcast` | [2,3] | [2,3,4] | [2,2,4] | Rank-2 LHS broadcast across batched RHS | ## Gemv diff --git a/validation/operations/conv/dynamic/conv_dynamic.onnx b/validation/operations/conv/dynamic/conv_dynamic.onnx index 2f63495..c21dcf3 100644 Binary files a/validation/operations/conv/dynamic/conv_dynamic.onnx and b/validation/operations/conv/dynamic/conv_dynamic.onnx differ diff --git a/validation/operations/gemm/dynamic/gemm_dynamic.onnx b/validation/operations/gemm/dynamic/gemm_dynamic.onnx index 917113d..f23e103 100644 Binary files a/validation/operations/gemm/dynamic/gemm_dynamic.onnx and b/validation/operations/gemm/dynamic/gemm_dynamic.onnx differ diff --git a/validation/operations/gemm/dynamic_alpha/gemm_dynamic_alpha.onnx b/validation/operations/gemm/dynamic_alpha/gemm_dynamic_alpha.onnx index 4decf30..2fdccb3 100644 Binary files a/validation/operations/gemm/dynamic_alpha/gemm_dynamic_alpha.onnx and b/validation/operations/gemm/dynamic_alpha/gemm_dynamic_alpha.onnx differ diff --git a/validation/operations/gemm/dynamic_beta/gemm_dynamic_beta.onnx b/validation/operations/gemm/dynamic_beta/gemm_dynamic_beta.onnx index 1ea67b7..716d64d 100644 Binary files a/validation/operations/gemm/dynamic_beta/gemm_dynamic_beta.onnx and b/validation/operations/gemm/dynamic_beta/gemm_dynamic_beta.onnx differ diff --git a/validation/operations/gemm/dynamic_bias/gemm_dynamic_bias.onnx b/validation/operations/gemm/dynamic_bias/gemm_dynamic_bias.onnx index 69d3a3a..5ffd977 100644 Binary files a/validation/operations/gemm/dynamic_bias/gemm_dynamic_bias.onnx and b/validation/operations/gemm/dynamic_bias/gemm_dynamic_bias.onnx differ diff --git a/validation/operations/gemm/dynamic_bias_alpha_beta/gemm_dynamic_bias_alpha_beta.onnx b/validation/operations/gemm/dynamic_bias_alpha_beta/gemm_dynamic_bias_alpha_beta.onnx index ca50913..f5f03f3 100644 Binary files a/validation/operations/gemm/dynamic_bias_alpha_beta/gemm_dynamic_bias_alpha_beta.onnx and b/validation/operations/gemm/dynamic_bias_alpha_beta/gemm_dynamic_bias_alpha_beta.onnx differ diff --git a/validation/operations/gemm/dynamic_transB/gemm_dynamic_transB.onnx b/validation/operations/gemm/dynamic_transB/gemm_dynamic_transB.onnx index 82109e1..8e9eb94 100644 Binary files a/validation/operations/gemm/dynamic_transB/gemm_dynamic_transB.onnx and b/validation/operations/gemm/dynamic_transB/gemm_dynamic_transB.onnx differ diff --git a/validation/operations/gemm/simple/gemm_simple.onnx b/validation/operations/gemm/simple/gemm_simple.onnx index 01b21b6..237810f 100644 Binary files a/validation/operations/gemm/simple/gemm_simple.onnx and b/validation/operations/gemm/simple/gemm_simple.onnx differ diff --git a/validation/operations/gen_tests.py b/validation/operations/gen_tests.py index 0ba74b0..9714652 100644 --- a/validation/operations/gen_tests.py +++ b/validation/operations/gen_tests.py @@ -421,6 +421,53 @@ def matmul_batched_3d(): save_model(model, "matmul/batched_3d", "matmul_batched_3d.onnx") +def matmul_batched_3d_dynamic(): + """Batched 3D MatMul with both operands provided at runtime.""" + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4]) + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_batched_3d_dynamic", [A, B], [Y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/batched_3d_dynamic", "matmul_batched_3d_dynamic.onnx") + + +def matmul_batched_left_constant(): + """Batched 3D MatMul with constant LHS and runtime RHS.""" + rng = np.random.default_rng(70) + A = numpy_helper.from_array(rng.uniform(-1, 1, (2, 2, 3)).astype(np.float32), name="A") + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [2, 3, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4]) + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_batched_left_constant", [B], [Y], initializer=[A]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/batched_left_constant", "matmul_batched_left_constant.onnx") + + +def matmul_batched_rhs_broadcast(): + """Batched 3D MatMul with 2D constant RHS broadcast across batch.""" + rng = np.random.default_rng(71) + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4]) + B = numpy_helper.from_array(rng.uniform(-1, 1, (3, 4)).astype(np.float32), name="B") + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_batched_rhs_broadcast", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/batched_rhs_broadcast", "matmul_batched_rhs_broadcast.onnx") + + +def matmul_batched_lhs_broadcast(): + """Batched 3D MatMul with 2D runtime LHS broadcast across batched RHS.""" + rng = np.random.default_rng(72) + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2, 4]) + B = numpy_helper.from_array(rng.uniform(-1, 1, (2, 3, 4)).astype(np.float32), name="B") + node = helper.make_node("MatMul", ["A", "B"], ["Y"]) + graph = helper.make_graph([node], "matmul_batched_lhs_broadcast", [A], [Y], initializer=[B]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + save_model(model, "matmul/batched_lhs_broadcast", "matmul_batched_lhs_broadcast.onnx") + + # --------------------------------------------------------------------------- # Pooling tests # --------------------------------------------------------------------------- @@ -972,6 +1019,10 @@ if __name__ == "__main__": matmul_left_constant() matmul_dynamic() matmul_batched_3d() + matmul_batched_3d_dynamic() + matmul_batched_left_constant() + matmul_batched_rhs_broadcast() + matmul_batched_lhs_broadcast() print("\nGenerating Pooling tests:") maxpool_basic() diff --git a/validation/operations/matmul/batched_3d_dynamic/matmul_batched_3d_dynamic.onnx b/validation/operations/matmul/batched_3d_dynamic/matmul_batched_3d_dynamic.onnx new file mode 100644 index 0000000..c8037ed Binary files /dev/null and b/validation/operations/matmul/batched_3d_dynamic/matmul_batched_3d_dynamic.onnx differ diff --git a/validation/operations/matmul/batched_left_constant/matmul_batched_left_constant.onnx b/validation/operations/matmul/batched_left_constant/matmul_batched_left_constant.onnx new file mode 100644 index 0000000..91921f7 Binary files /dev/null and b/validation/operations/matmul/batched_left_constant/matmul_batched_left_constant.onnx differ diff --git a/validation/operations/matmul/batched_lhs_broadcast/matmul_batched_lhs_broadcast.onnx b/validation/operations/matmul/batched_lhs_broadcast/matmul_batched_lhs_broadcast.onnx new file mode 100644 index 0000000..b1d7810 Binary files /dev/null and b/validation/operations/matmul/batched_lhs_broadcast/matmul_batched_lhs_broadcast.onnx differ diff --git a/validation/operations/matmul/batched_rhs_broadcast/matmul_batched_rhs_broadcast.onnx b/validation/operations/matmul/batched_rhs_broadcast/matmul_batched_rhs_broadcast.onnx new file mode 100644 index 0000000..7981e76 Binary files /dev/null and b/validation/operations/matmul/batched_rhs_broadcast/matmul_batched_rhs_broadcast.onnx differ diff --git a/validation/operations/matmul/dynamic/matmul_dynamic.onnx b/validation/operations/matmul/dynamic/matmul_dynamic.onnx index f7fe1fd..30947ff 100644 Binary files a/validation/operations/matmul/dynamic/matmul_dynamic.onnx and b/validation/operations/matmul/dynamic/matmul_dynamic.onnx differ diff --git a/validation/operations/matmul/left_constant/matmul_left_constant.onnx b/validation/operations/matmul/left_constant/matmul_left_constant.onnx index 0cf483a..7f727e5 100644 Binary files a/validation/operations/matmul/left_constant/matmul_left_constant.onnx and b/validation/operations/matmul/left_constant/matmul_left_constant.onnx differ