#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/SmallVector.h"

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

/// Folds a scalar `factor` into a dense floating-point constant.
///
/// - If `factor` is exactly 1.0, `value` is returned untouched (no IR is created).
/// - Otherwise `value` must be produced directly by a constant op whose attribute is a
///   dense FP elements attribute; any other producer (including a transpose/expand of a
///   constant) yields failure(), so callers must scale BEFORE transforming the value.
/// - The factor is first converted to the element type's float semantics, because APFloat
///   arithmetic requires both operands to share semantics (APFloat(float) is IEEE single).
/// - Ordinary rounding (inexact/underflow) is accepted; the fold fails only when converting
///   the factor or forming a product is an invalid operation or overflows.
///
/// Returns the result of a newly created arith constant holding the scaled values.
static FailureOr materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  if (factor == 1.0f)
    return value;

  auto constantOp = value.getDefiningOp();
  if (!constantOp)
    return failure();

  auto denseAttr = dyn_cast(constantOp.getValue());
  if (!denseAttr)
    return failure();

  // Rounding is inherent when scaling; only NaN-producing or overflowing results make the
  // folded constant unfaithful.
  const auto isUnrepresentable = [](APFloat::opStatus status) {
    return (status & APFloat::opInvalidOp) || (status & APFloat::opOverflow);
  };

  SmallVector scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  APFloat scale(factor);
  for (const APFloat& originalValue : denseAttr.getValues()) {
    // All elements share one semantics, so this conversion happens at most once.
    if (&scale.getSemantics() != &originalValue.getSemantics()) {
      bool losesInfo = false;
      if (isUnrepresentable(scale.convert(originalValue.getSemantics(), APFloat::rmNearestTiesToEven, &losesInfo)))
        return failure();
    }
    APFloat scaledValue(originalValue);
    if (isUnrepresentable(scaledValue.multiply(scale, APFloat::rmNearestTiesToEven)))
      return failure();
    scaledValues.push_back(std::move(scaledValue));
  }

  auto scaledAttr = DenseFPElementsAttr::get(cast(denseAttr.getType()), scaledValues);
  return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult();
}

/// Transposes `value` into `resultType` using `permutation`.
///
/// Host-foldable inputs get a plain ONNX transpose at the current insertion point; any
/// other input is transposed inside a single-input spat compute that yields the result.
static Value transposeForSpatial(Value value, RankedTensorType resultType, ArrayRef permutation, ConversionPatternRewriter& rewriter, Location loc) {
  if (isHostFoldableValue(value))
    return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation));

  auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
    Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation));
    spatial::SpatYieldOp::create(rewriter, loc, transposed);
  });
  return computeOp.getResult(0);
}

/// Expands a rank-1 bias into the rank-2 `resultType` with reassociation {{0, 1}}.
///
/// Host-foldable inputs are expanded directly; otherwise the expansion is wrapped in a
/// single-input spat compute that yields the expanded tensor.
static Value expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  if (isHostFoldableValue(value))
    return tensor::ExpandShapeOp::create(rewriter, loc, resultType, value, SmallVector {
      {0, 1}
    });

  auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) {
    Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, resultType, input, SmallVector {
      {0, 1}
    });
    spatial::SpatYieldOp::create(rewriter, loc, expanded);
  });
  return computeOp.getResult(0);
}

/// Splits a multi-row Gemm into one single-row Gemm per row of A (see matchAndRewrite).
struct GemmToManyGemv : OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const override;
};

/// Lowers a vector-shaped Gemm to crossbar-tiled Spatial VMM computes.
struct GemvToSpatialCompute : OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const override;
};

/// Lowers a multi-row, single-tile Gemm to one Spatial compute batch (one lane per row).
struct GemmToSpatialComputeBatch : OpConversionPattern {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const override;
};

/// Produces one [1, cols] value per row of the rank-2 `matrix`.
///
/// Host-foldable matrices are split directly with a spat extract-rows op. Otherwise the
/// function walks up from `matrix` through single-operand producers accepted by the `isa`
/// filter below, then clones that chain (in program order) into a new spat compute whose
/// only input is the chain's root, extracting the rows there. The walk stops at a spat
/// compute producer, a multi-operand producer, a filtered-out producer, or a block argument.
static SmallVector materializeBatchRowSlices(Value matrix, RankedTensorType matrixType, ConversionPatternRewriter& rewriter, Location loc) {
  const int64_t numRows = matrixType.getDimSize(0);
  auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType());
  SmallVector resultTypes(static_cast(numRows), rowType);

  if (isHostFoldableValue(matrix)) {
    auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix);
    return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end());
  }

  auto buildRowSlices = [&](Value matrixArg) {
    auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg);
    return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end());
  };

  // Builds a compute fed by `rootInput`; inside it, `rootValue` is remapped to the block
  // argument, `chainOps` are re-created in order, and the rows of the result are yielded.
  auto cloneBatchInputChainIntoSliceCompute = [&](Value rootInput, SmallVector chainOps, Value rootValue) -> SmallVector {
    auto sliceCompute = createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) {
      Value transformedMatrix = input;
      if (!chainOps.empty()) {
        IRMapping mapper;
        mapper.map(rootValue, input);
        for (Operation* chainOp : chainOps)
          rewriter.clone(*chainOp, mapper);
        transformedMatrix = cast(mapper.lookup(matrix));
      }
      spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix));
    });
    SmallVector rowSlices(sliceCompute->result_begin(), sliceCompute->result_end());
    return rowSlices;
  };

  // chainOps is collected consumer-first; it is reversed before cloning.
  SmallVector chainOps;
  Value rootValue = matrix;
  while (Operation* definingOp = rootValue.getDefiningOp()) {
    if (auto rootCompute = dyn_cast(definingOp)) {
      SmallVector reversedChainOps(chainOps.rbegin(), chainOps.rend());
      return cloneBatchInputChainIntoSliceCompute(
          rootCompute.getResult(cast(rootValue).getResultNumber()), reversedChainOps, rootValue);
    }
    if (definingOp->getNumOperands() != 1)
      break;
    if (!isa(definingOp))
      break;
    chainOps.push_back(definingOp);
    rootValue = definingOp->getOperand(0);
  }

  SmallVector reversedChainOps(chainOps.rbegin(), chainOps.rend());
  return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue);
}

} // namespace

/// Decomposes Y = alpha * A * B + beta * C (A with more than one row) into per-row Gemms.
///
/// Requires transA = false and static shapes for A and Y. Alpha is folded into B and beta
/// into C (each must be a dense FP constant unless its factor is 1). A rank-1 C is expanded
/// to [1, N]. If C has one row per output row it is sliced per row; otherwise it must be
/// vector-like and is shared by every row. Each row becomes an ONNX Gemm with alpha = beta = 1
/// (keeping the original transA/transB attributes), and the rows are concatenated along axis 0
/// inside a spat compute.
LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  if (gemmOpAdaptor.getTransA()) {
    gemmOp.emitOpError("requires transA=false before Gemm row decomposition");
    return failure();
  }

  // The isa check identifies the "no bias" placeholder (presumably ONNX's none-value op).
  // C may also be a block argument with no defining op; that is a real bias, and isa must
  // never be applied to a null pointer.
  Operation* cDefiningOp = c.getDefiningOp();
  bool hasC = !cDefiningOp || !isa(cDefiningOp);

  auto aType = cast(a.getType());
  auto outType = cast(gemmOp.getY().getType());
  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }

  const int64_t numOutRows = aType.getDimSize(0);
  // Only decompose when there are multiple rows to split
  if (numOutRows <= 1)
    return failure();

  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;

  RankedTensorType cType = nullptr;
  bool cHasNumOutRows = false;
  if (hasC) {
    // Scale before expanding: the scaling helper only accepts the constant itself.
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    cType = cast(c.getType());
    // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = expandRankOneBias(c, expandedType, rewriter, loc);
      cType = expandedType;
    }
    if (!cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    if (cType.getRank() != 2) {
      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
      return failure();
    }
    cHasNumOutRows = cType.getDimSize(0) == numOutRows;
    // A bias shared across rows must be vector-like; reject it before any row slices exist.
    if (!cHasNumOutRows && !isVectorShape(getTensorShape(c))) {
      gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows");
      return failure();
    }
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());

  SmallVector aSlices = materializeBatchRowSlices(a, aType, rewriter, loc);
  SmallVector cSlices;
  if (hasC && cHasNumOutRows)
    cSlices = materializeBatchRowSlices(c, cType, rewriter, loc);

  SmallVector gemvOps;
  gemvOps.reserve(static_cast(numOutRows));
  for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) {
    // Per-row bias when C has one row per output row; otherwise C (or the none value) is shared.
    Value cSlice = c;
    if (hasC && cHasNumOutRows)
      cSlice = cSlices[static_cast(rowIdx)];

    auto gemvOp = ONNXGemmOp::create(rewriter, loc, outRowType, aSlices[static_cast(rowIdx)], b, cSlice, rewriter.getF32FloatAttr(1.0f), rewriter.getF32FloatAttr(1.0f), gemmOp.getTransAAttr(), gemmOp.getTransBAttr());
    gemvOps.push_back(gemvOp.getY());
  }

  auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) {
    spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs));
  });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}

/// Lowers a Gemm whose A (and C, if present) are vector-shaped to tiled Spatial VMMs.
///
/// All of C, A, B and Y must have static shapes, and a rank-1 C is treated as [1, N].
/// Alpha and beta are folded into the B and C constants first. Only then are the other
/// transforms applied: C is expanded, and A/B are transposed when transA/transB are set.
/// The scaling helper only recognizes an untransformed constant, so this order matters.
/// B is tiled into crossbarSize x crossbarSize tiles. For each horizontal output slice,
/// every core's VMM partial sums are computed in their own spat compute, then reduced
/// (together with the matching bias slice) in a second compute. The slices are finally
/// concatenated along axis 1.
LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const {
  Location gemmLoc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();
  Value out = gemmOp.getY();

  float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
  float beta = gemmOpAdaptor.getBeta().convertToFloat();
  bool transA = gemmOpAdaptor.getTransA();
  bool transB = gemmOpAdaptor.getTransB();

  auto aType = cast(a.getType());
  auto bType = cast(b.getType());
  auto outType = cast(out.getType());

  // See GemmToManyGemv: a block-argument C has no defining op and is a real bias.
  Operation* cDefiningOp = c.getDefiningOp();
  bool hasC = !cDefiningOp || !isa(cDefiningOp);

  // Validate C using its rank-2 form, but defer the actual expansion until beta has been
  // folded into the original constant.
  RankedTensorType cType = nullptr;
  bool expandC = false;
  if (hasC) {
    cType = cast(c.getType());
    // A rank-1 bias [N] is handled as rank-2 [1, N]
    if (cType.getRank() == 1) {
      cType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      expandC = true;
    }
    if (!cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    if (cType.getRank() != 2) {
      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
      return failure();
    }
  }

  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!bType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }

  if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape())))
    // Not a gemv
    return failure();

  // Fold the scalar factors while B and C are still the constants themselves.
  if (alpha != 1.0f) {
    auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc);
    if (failed(scaledB))
      return failure();
    b = *scaledB;
    bType = cast(b.getType());
  }
  if (hasC && beta != 1.0f) {
    auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
  }
  if (expandC)
    c = expandRankOneBias(c, cType, rewriter, gemmLoc);

  if (transA) {
    auto aShape = aType.getShape();
    auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType());
    a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc);
    aType = cast(a.getType());
  }
  if (transB) {
    auto bShape = bType.getShape();
    auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
    b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc);
    bType = cast(b.getType());
  }

  auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
  auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue());
  auto bNumVSlices = aNumHSlices;
  auto cNumHSlices = bNumHSlices;
  auto cLastHSliceSize = bLastHSliceSize;
  auto outNumHSlices = cNumHSlices;
  auto outLastHSliceSize = cLastHSliceSize;

  const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue());

  DenseMap> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc);

  DenseMap>> bTiles = tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc);

  SmallVector cHSlices;
  // A [1, 1] bias is broadcast to the full output width before slicing.
  if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
    c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc);
  if (hasC)
    cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc);

  RankedTensorType outHSliceType = RankedTensorType::get({1, static_cast(crossbarSize)}, outType.getElementType());
  RankedTensorType outLastHSliceType = RankedTensorType::get({1, static_cast(bLastHSliceSize)}, outType.getElementType());

  SmallVector outHSlices;
  outHSlices.reserve(outNumHSlices);
  for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) {
    // The last slice is narrower when N is not a multiple of crossbarSize.
    RankedTensorType currOutHSliceType = outHSliceType;
    if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0)
      currOutHSliceType = outLastHSliceType;

    SmallVector partialResults;
    partialResults.reserve(coresPerVSlice);
    for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) {
      SmallVector weights;
      weights.reserve(aHSlices[coreId].size());

      for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++)
        weights.push_back(bTiles[outSliceId][coreId][aSliceId]);

      auto computeOp = spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]);

      // Block arguments mirror the operands: all weights first, then all inputs.
      SmallVector blockArgTypes;
      SmallVector blockArgLocs;
      blockArgTypes.reserve(weights.size() + aHSlices[coreId].size());
      blockArgLocs.reserve(weights.size() + aHSlices[coreId].size());
      for (Value weight : weights) {
        blockArgTypes.push_back(weight.getType());
        blockArgLocs.push_back(gemmLoc);
      }
      for (Value input : aHSlices[coreId]) {
        blockArgTypes.push_back(input.getType());
        blockArgLocs.push_back(gemmLoc);
      }
      Block* body = rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
      rewriter.setInsertionPointToEnd(body);

      SmallVector vmmOutputs;
      vmmOutputs.reserve(aHSlices[coreId].size());
      for (auto aHSliceId : llvm::seq(0, aHSlices[coreId].size())) {
        auto weightArg = computeOp.getWeightArgument(aHSliceId);
        auto inputArg = computeOp.getInputArgument(aHSliceId);
        if (!weightArg || !inputArg)
          return failure();
        vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg));
      }
      if (vmmOutputs.empty()) {
        gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
        return failure();
      }
      Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
      spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
      rewriter.setInsertionPointAfter(computeOp);

      partialResults.push_back(computeOp->getResult(0));
    }

    if (hasC) {
      Value cHSlice = cHSlices[outSliceId];
      partialResults.push_back(cHSlice);
    }

    auto reduceComputeOp = createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) {
      SmallVector values(blockArgs.begin(), blockArgs.end());
      Value outHSlice = sumTensors(values, rewriter);
      spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
    });

    outHSlices.push_back(reduceComputeOp.getResult(0));
  }

  auto concatComputeOp = createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) {
    spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs));
  });

  rewriter.replaceOp(gemmOp, concatComputeOp);
  return success();
}

/// Lowers a multi-row Gemm that fits a single crossbar tile (K and N both at most
/// crossbarSize) into one spat compute batch with one lane per row of A.
///
/// Requires transA = false and static A, B and Y. Alpha is folded into B before any
/// transB transpose. Beta is folded into C before a rank-1 C is expanded. A bias with
/// one row per output row is rejected so that another pattern can handle it. A [1, 1]
/// bias is broadcast to [1, N]. Each lane:
///   1. extracts its row of A,
///   2. runs a VMM against B,
///   3. adds the shared bias (if any),
///   4. writes the [1, N] result into its row of the packed output via a parallel insert slice.
LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  if (gemmOpAdaptor.getTransA()) {
    gemmOp.emitOpError("requires transA=false before batch Gemm lowering");
    return failure();
  }

  // See GemmToManyGemv: a block-argument C has no defining op and is a real bias.
  Operation* cDefiningOp = c.getDefiningOp();
  bool hasC = !cDefiningOp || !isa(cDefiningOp);

  auto aType = cast(a.getType());
  auto bType = cast(b.getType());
  auto outType = cast(gemmOp.getY().getType());
  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!bType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }

  const int64_t numOutRows = aType.getDimSize(0);
  if (numOutRows <= 1)
    return failure();

  // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize
  if (aType.getDimSize(1) > static_cast(crossbarSize.getValue()) || outType.getDimSize(1) > static_cast(crossbarSize.getValue()))
    return failure();

  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB))
    return failure();
  b = *scaledB;
  bType = cast(b.getType());

  if (gemmOpAdaptor.getTransB()) {
    auto bShape = bType.getShape();
    auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType());
    b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc);
    bType = cast(b.getType());
  }

  Value sharedBias;
  if (hasC) {
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC))
      return failure();
    c = *scaledC;
    auto cType = cast(c.getType());
    if (cType.getRank() == 1) {
      auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType());
      c = expandRankOneBias(c, expandedType, rewriter, loc);
      cType = cast(c.getType());
    }
    if (!cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    if (cType.getRank() != 2) {
      pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2});
      return failure();
    }
    // Row-specific bias can't share a single template body; fall through to GemmToManyGemv
    // (numOutRows > 1 is already guaranteed above).
    if (cType.getDimSize(0) == numOutRows)
      return failure();
    if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1)
      c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc);
    sharedBias = c;
  }

  auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType());
  auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
  auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, TypeRange {outType}, rewriter.getI32IntegerAttr(static_cast(numOutRows)), ValueRange {b}, ValueRange {a});

  // Block arguments: lane index, weight (B), packed input (A), packed output (Y).
  SmallVector blockArgTypes {rewriter.getIndexType(), bType, aType, outType};
  SmallVector blockArgLocs(4, loc);
  Block* body = rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
  rewriter.setInsertionPointToEnd(body);

  auto lane = batchOp.getLaneArgument();
  auto weight = batchOp.getWeightArgument(0);
  auto packedInput = batchOp.getInputArgument(0);
  auto packedOutput = batchOp.getOutputArgument(0);
  if (!lane || !weight || !packedInput || !packedOutput)
    return failure();

  SmallVector inputOffsets {*lane, rewriter.getIndexAttr(0)};
  SmallVector inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
  SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  Value row = tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
                  .getResult();

  Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
  Value laneResult = vmmResult;
  // NOTE(review): sharedBias is defined outside batchOp and used inside its body; this
  // assumes the batch region is not isolated from above — confirm against the op definition.
  if (sharedBias)
    laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();

  auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
  rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
  SmallVector outputOffsets {*lane, rewriter.getIndexAttr(0)};
  SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
  tensor::ParallelInsertSliceOp::create(
      rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);

  rewriter.setInsertionPointAfter(batchOp);
  rewriter.replaceOp(gemmOp, batchOp.getResults());
  return success();
}

/// Registers the three Gemm lowering patterns. The first registration carries
/// PatternBenefit(2), so the driver tries it before the two default-benefit patterns.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert(ctx, PatternBenefit(2));
  patterns.insert(ctx);
  patterns.insert(ctx);
}

} // namespace onnx_mlir