#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" #include #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { namespace { static FailureOr materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) { if (factor == 1.0f) return value; auto constantOp = value.getDefiningOp(); if (!constantOp) return failure(); auto denseAttr = dyn_cast(constantOp.getValue()); if (!denseAttr) return failure(); SmallVector scaledValues; scaledValues.reserve(denseAttr.getNumElements()); APFloat scale(factor); bool hadFailure = false; for (const APFloat& originalValue : denseAttr.getValues()) { APFloat scaledValue(originalValue); if (scaledValue.multiply(scale, APFloat::rmNearestTiesToEven)) hadFailure = true; scaledValues.push_back(std::move(scaledValue)); } if (hadFailure) return failure(); auto scaledAttr = DenseFPElementsAttr::get(cast(denseAttr.getType()), scaledValues); return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); } struct GemmToManyGemv : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const override; }; struct GemvToSpatialCompute : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const override; }; struct GemmToSpatialComputeBatch : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const override; }; static SmallVector materializeBatchRowSlices(Value matrix, RankedTensorType matrixType, ConversionPatternRewriter& rewriter, Location loc) { const int64_t numRows = matrixType.getDimSize(0); auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType()); SmallVector resultTypes(static_cast(numRows), rowType); auto buildRowSlices = [&](Value matrixArg) { auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg); return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); }; auto cloneBatchInputChainIntoSliceCompute = [&](Value rootInput, SmallVector chainOps, Value rootValue) -> SmallVector { auto sliceCompute = createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) { Value transformedMatrix = input; if (!chainOps.empty()) { IRMapping mapper; mapper.map(rootValue, input); for (Operation* chainOp : chainOps) rewriter.clone(*chainOp, mapper); transformedMatrix = cast(mapper.lookup(matrix)); } spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix)); }); SmallVector rowSlices(sliceCompute->result_begin(), sliceCompute->result_end()); return rowSlices; }; SmallVector chainOps; Value rootValue = matrix; while (Operation* definingOp = rootValue.getDefiningOp()) { if (auto 
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
    ONNXGemmOpAdaptor gemmOpAdaptor,
    ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());

  auto aType = cast<RankedTensorType>(a.getType());
  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
  assert("Only support static shapes" && aType.hasStaticShape() &&
         outType.hasStaticShape());

  const int64_t numOutRows = aType.getDimSize(0);
  // Only decompose when there are multiple rows to split.
  if (numOutRows <= 1)
    return failure();

  auto scaledB = materializeScaledConstantTensor(
      b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  bool cHasNumOutRows = false;
  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(
        c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling.
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get(
          {1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c,
          SmallVector<ReassociationIndices>{{0, 1}});
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
  }

  auto outRowType = RankedTensorType::get(
      {1, outType.getDimSize(1)}, outType.getElementType());
  SmallVector<Value> gemvOps;
  gemvOps.reserve(numOutRows);
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    SmallVector<OpFoldResult> offsets = {
        rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> sizes = {
        rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
    SmallVector<OpFoldResult> strides = {
        rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    auto aSliceType = RankedTensorType::get(
        {1, aType.getDimSize(1)}, aType.getElementType());
    auto aSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, aSliceType, a, offsets, sizes, strides)
                      .getResult();

    Value cSlice = c;
    if (hasC) {
      if (cHasNumOutRows) {
        // C carries one bias row per output row; slice out the matching row.
        SmallVector<OpFoldResult> offsets = {
            rewriter.getIndexAttr(rowIdx), rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
            rewriter.getIndexAttr(cType.getDimSize(1))};
        SmallVector<OpFoldResult> strides = {
            rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        auto cSliceType = RankedTensorType::get(
            {1, cType.getDimSize(1)}, cType.getElementType());
        cSlice = tensor::ExtractSliceOp::create(
            rewriter, loc, cSliceType, c, offsets, sizes, strides)
                     .getResult();
      } else {
        assert("C should be a vector" && isVectorShape(getTensorShape(c)));
      }
    }

    // alpha and beta are already folded into B and C, so each row Gemm uses
    // 1.0 for both.
    auto gemvOp = ONNXGemmOp::create(rewriter, loc, outRowType, aSlice, b,
        cSlice, rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f),
        gemmOp.getTransAAttr(), gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
  }

  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {},
      gemvOps, [&](ValueRange gemvOpsArgs) {
        spatial::SpatYieldOp::create(rewriter, loc,
            createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
      });
  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
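
// GemvToSpatialCompute maps a single-row Gemm (a gemv) onto the crossbar
// fabric. B is tiled into crossbarSize x crossbarSize weight tiles and the
// input row is sliced per crossbar and per core; each core computes a partial
// sum of weighted VMMs over its tiles, the per-core partials (plus the
// matching bias slice) are reduced per output slice, and the output slices
// are concatenated along axis 1.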
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
    ONNXGemmOpAdaptor gemmOpAdaptor,
    ConversionPatternRewriter& rewriter) const {
  Location gemmLoc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  Value out = gemmOp.getY();
  float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  float beta = gemmOpAdaptor.getBeta().convertToFloat();
  bool transA = gemmOpAdaptor.getTransA();
  bool transB = gemmOpAdaptor.getTransB();

  auto aType = cast<RankedTensorType>(a.getType());
  auto bType = cast<RankedTensorType>(b.getType());
  auto outType = cast<RankedTensorType>(out.getType());
  RankedTensorType cType = nullptr;
  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());
  if (hasC) {
    cType = cast<RankedTensorType>(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling.
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get(
          {1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, gemmLoc, expandedType, c,
          SmallVector<ReassociationIndices>{{0, 1}});
      cType = expandedType;
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
  }
  assert("Only support static shapes" && aType.hasStaticShape() &&
         bType.hasStaticShape() && (!hasC || cType.hasStaticShape()) &&
         outType.hasStaticShape());

  if (!isVectorShape(aType.getShape()) ||
      (hasC && !isVectorShape(cType.getShape())))
    // Not a gemv.
    return failure();

  if (transA) {
    auto aShape = aType.getShape();
    auto transposedType = aType.cloneWith(
        ArrayRef<int64_t>({aShape[1], aShape[0]}), aType.getElementType());
    a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a,
        rewriter.getI64ArrayAttr({1, 0}));
  }
  if (transB) {
    auto bShape = bType.getShape();
    auto transposedType = bType.cloneWith(
        ArrayRef<int64_t>({bShape[1], bShape[0]}), bType.getElementType());
    b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b,
        rewriter.getI64ArrayAttr({1, 0}));
    bType = cast<RankedTensorType>(b.getType());
  }

  // Fold alpha into B and beta into C so the crossbar sees plain weights.
  if (alpha != 1.0f) {
    auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
    if (failed(scaledB))
      return failure();
    b = *scaledB;
    bType = cast<RankedTensorType>(b.getType());
    alpha = 1.0f;
  }
  if (hasC && beta != 1.0f) {
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast<RankedTensorType>(c.getType());
    beta = 1.0f;
  }

  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(
      aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(
      bType.getDimSize(1), crossbarSize.getValue());
  // A's horizontal slicing induces B's vertical slicing; B's horizontal
  // slicing induces the slicing of C and of the output.
  auto bNumVSlices = aNumHSlices;
  auto bLastVSliceSize = aLastHSliceSize;
  auto cNumHSlices = bNumHSlices;
  auto cLastHSliceSize = bLastHSliceSize;
  auto outNumHSlices = cNumHSlices;
  auto outLastHSliceSize = cLastHSliceSize;
  const size_t coresPerVSlice =
      ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

  DenseMap<size_t, SmallVector<Value>> aHSlices =
      sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);
  DenseMap<size_t, DenseMap<size_t, SmallVector<Value>>> bTiles =
      tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);
  SmallVector<Value> cHSlices;
  if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
    c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
  if (hasC)
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);

  RankedTensorType outHSliceType = RankedTensorType::get(
      {1, static_cast<int64_t>(crossbarSize)}, outType.getElementType());
  RankedTensorType outLastHSliceType = RankedTensorType::get(
      {1, static_cast<int64_t>(bLastHSliceSize)}, outType.getElementType());

  SmallVector<Value> outHSlices;
  outHSlices.reserve(outNumHSlices);
  for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
    RankedTensorType currOutHSliceType = outHSliceType;
    if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
      currOutHSliceType = outLastHSliceType;

    SmallVector<Value> partialResults;
    partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      SmallVector<Value> weights;
      weights.reserve(aHSlices[coreId].size());
      for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
      auto computeOp = createSpatCompute(rewriter, gemmLoc, currOutHSliceType,
          weights, aHSlices[coreId], [&](ValueRange aHSlicesArgs) {
            SmallVector<Value> vmmOutputs;
            vmmOutputs.reserve(aHSlicesArgs.size());
            for (auto [aHSliceId, computeArg] : llvm::enumerate(aHSlicesArgs))
              vmmOutputs.push_back(spatial::SpatWeightedVMMOp::create(rewriter,
                  gemmLoc, currOutHSliceType, aHSliceId, computeArg));
            assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
            Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
            spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
          });
      partialResults.push_back(computeOp.getResult(0));
    }
    if (hasC) {
      Value cHSlice = cHSlices[outSliceId];
      partialResults.push_back(cHSlice);
    }
    auto reduceComputeOp = createSpatCompute(rewriter, gemmLoc,
        currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
          SmallVector<Value> values(blockArgs.begin(), blockArgs.end());
          Value outHSlice = sumTensors(values, rewriter);
          spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
        });
    outHSlices.push_back(reduceComputeOp.getResult(0));
  }

  auto concatComputeOp = createSpatCompute(rewriter, gemmLoc, gemmOp.getType(),
      {}, outHSlices, [&](ValueRange blockArgs) {
        spatial::SpatYieldOp::create(rewriter, gemmLoc,
            createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
      });
  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
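
// GemmToSpatialComputeBatch covers the single-tile case (both K and N fit in
// one crossbar): all output rows share the same weights and bias, so the Gemm
// lowers to one SpatComputeBatch whose lane body performs a weighted VMM plus
// an optional shared bias add. Row-specific biases are rejected here and
// handled by GemmToManyGemv instead.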
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
    ONNXGemmOpAdaptor gemmOpAdaptor,
    ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  assert("A should have been transposed already" && !gemmOpAdaptor.getTransA());
  bool hasC = !isa<ONNXNoneOp>(c.getDefiningOp());

  auto aType = cast<RankedTensorType>(a.getType());
  auto bType = cast<RankedTensorType>(b.getType());
  auto outType = cast<RankedTensorType>(gemmOp.getY().getType());
  assert("Only support static shapes" && aType.hasStaticShape() &&
         bType.hasStaticShape() && outType.hasStaticShape());

  const int64_t numOutRows = aType.getDimSize(0);
  if (numOutRows <= 1)
    return failure();
  // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize.
  if (aType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()) ||
      outType.getDimSize(1) > static_cast<int64_t>(crossbarSize.getValue()))
    return failure();

  auto scaledB = materializeScaledConstantTensor(
      b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;
  bType = cast<RankedTensorType>(b.getType());
  if (gemmOpAdaptor.getTransB()) {
    auto bShape = bType.getShape();
    auto transposedType = bType.cloneWith(
        ArrayRef<int64_t>({bShape[1], bShape[0]}), bType.getElementType());
    b = ONNXTransposeOp::create(rewriter, loc, transposedType, b,
        rewriter.getI64ArrayAttr({1, 0}));
    bType = cast<RankedTensorType>(b.getType());
  }
  (void) bType;

  Value sharedBias;
  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(
        c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    auto cType = cast<RankedTensorType>(c.getType());
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get(
          {1, cType.getDimSize(0)}, cType.getElementType());
      c = tensor::ExpandShapeOp::create(rewriter, loc, expandedType, c,
          SmallVector<ReassociationIndices>{{0, 1}});
      cType = cast<RankedTensorType>(c.getType());
    }
    assert("Only support rank 2 tensor for C" && cType.getRank() == 2);
    // Row-specific bias can't share a single template body; fall through to
    // GemmToManyGemv.
    if (cType.getDimSize(0) == numOutRows && numOutRows > 1)
      return failure();
    if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
      c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
    sharedBias = c;
  }

  SmallVector<Value> aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
  auto aSliceType = cast<RankedTensorType>(aSlices.front().getType());
  auto outRowType = RankedTensorType::get(
      {1, outType.getDimSize(1)}, outType.getElementType());
  SmallVector<Type> resultTypes(static_cast<size_t>(numOutRows), outRowType);
  SmallVector<Value> weights(static_cast<size_t>(numOutRows), b);

  auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc,
      TypeRange(resultTypes),
      rewriter.getI32IntegerAttr(static_cast<int32_t>(numOutRows)),
      ValueRange(weights), ValueRange(aSlices));
  // Build the shared lane body: one weighted VMM per row, plus the shared
  // bias when present.
  Block* body = rewriter.createBlock(&batchOp.getBody(),
      batchOp.getBody().end(), TypeRange{aSliceType},
      SmallVector<Location>(1, loc));
  rewriter.setInsertionPointToEnd(body);
  Value vmmResult = spatial::SpatWeightedVMMOp::create(
      rewriter, loc, outRowType, 0, body->getArgument(0))
                        .getResult();
  Value laneResult = vmmResult;
  if (sharedBias)
    laneResult = spatial::SpatVAddOp::create(
        rewriter, loc, outRowType, vmmResult, sharedBias)
                     .getResult();
  spatial::SpatYieldOp::create(rewriter, loc, laneResult);
  rewriter.setInsertionPointAfter(batchOp);

  SmallVector<Value> laneResults(batchOp->result_begin(), batchOp->result_end());
  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {},
      laneResults, [&](ValueRange args) {
        spatial::SpatYieldOp::create(rewriter, loc,
            createSpatConcat(rewriter, loc, /*axis=*/0, args));
      });
  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}
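
// Registration order: the batch pattern carries a higher benefit so it is
// tried first; when it bails out (multi-tile shapes, row-specific bias),
// GemmToManyGemv splits the Gemm into gemvs, which GemvToSpatialCompute then
// lowers.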
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToSpatialComputeBatch>(ctx, PatternBenefit(2));
  patterns.insert<GemmToManyGemv>(ctx);
  patterns.insert<GemvToSpatialCompute>(ctx);
}

} // namespace onnx_mlir