#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/SmallVector.h"

#include <cstdint>
#include <limits>
#include <utility>

#include "Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {
namespace {

// Returns `value` multiplied element-wise by `factor`, materialized as a new host
// constant. When `factor` is exactly 1.0 the input value is returned unchanged.
//
// Fails when `value` is not a host-constant floating-point tensor, or when the
// scaling yields an invalid operation or overflows the element type. Ordinary
// rounding (APFloat::opInexact) and gradual underflow are accepted: they are the
// normal outcome of a correctly rounded floating-point multiply.
static FailureOr<Value> materializeScaledConstantTensor(
    Value value, float factor, ConversionPatternRewriter& rewriter, Location loc) {
  if (factor == 1.0f)
    return value;

  auto denseAttr = dyn_cast_or_null<DenseFPElementsAttr>(getHostConstDenseElementsAttr(value));
  if (!denseAttr)
    return failure();

  // Only statuses that mean the result is no longer a faithful product are fatal.
  auto isFatalStatus = [](APFloat::opStatus status) {
    return (status & (APFloat::opInvalidOp | APFloat::opOverflow)) != 0;
  };

  // APFloat arithmetic requires both operands to share the same semantics, so the
  // (IEEE single) scale factor is converted to the tensor's element semantics first;
  // otherwise f16/bf16/f64 constants would hit APFloat's same-semantics assertion.
  const llvm::fltSemantics& elementSemantics = cast<FloatType>(denseAttr.getElementType()).getFloatSemantics();
  APFloat scale(factor);
  bool losesInfo = false;
  if (isFatalStatus(scale.convert(elementSemantics, APFloat::rmNearestTiesToEven, &losesInfo)))
    return failure();

  SmallVector<APFloat> scaledValues;
  scaledValues.reserve(denseAttr.getNumElements());
  for (const APFloat& originalValue : denseAttr.getValues<APFloat>()) {
    APFloat scaledValue(originalValue);
    if (isFatalStatus(scaledValue.multiply(scale, APFloat::rmNearestTiesToEven)))
      return failure();
    scaledValues.push_back(std::move(scaledValue));
  }

  auto scaledAttr = DenseFPElementsAttr::get(denseAttr.getType(), scaledValues);
  return getOrCreateConstant(
      rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType());
}

// Tiled VMM batch lanes are laid out as
//   lane = hSlice * (numKSlices * numOutRows) + kSlice * numOutRows + row
// (see also createPartialGroupOffset). This returns the K (reduction) offset of the
// lane's crossbar tile: ((lane floordiv numOutRows) mod numKSlices) * crossbarSize.
static Value createGemmBatchKOffset(
    Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) {
  if (numKSlices == 1)
    return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createOrFoldAffineApply(rewriter,
      loc,
      (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(),
      ValueRange {lane},
      rewriter.getInsertionBlock()->getParentOp());
}

// Returns the output-column (H) offset of a tiled VMM batch lane, using the same
// lane layout as createGemmBatchKOffset:
//   (lane floordiv (numOutRows * numKSlices)) * crossbarSize.
static Value createGemmBatchHOffset(Value lane,
    int64_t numOutRows,
    int64_t numKSlices,
    int64_t numOutHSlices,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  if (numOutHSlices == 1)
    return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createOrFoldAffineApply(rewriter,
      loc,
      d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(),
      ValueRange {lane},
      rewriter.getInsertionBlock()->getParentOp());
}

// Pads `value` with zeros at the high end of every dimension so that it has
// `resultType`'s shape. Emits a tensor.pad whose region yields a zero of the
// element type. On return the insertion point is placed right after the pad op.
static Value
createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highPads;
  highPads.reserve(sourceType.getRank());
  for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
    highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));

  auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);

  // The pad region takes one index argument per dimension. Creating the block through
  // the rewriter (rather than a raw `new Block`) keeps the conversion driver informed;
  // createBlock also moves the insertion point into the new block.
  SmallVector<Type> blockArgTypes(sourceType.getRank(), rewriter.getIndexType());
  SmallVector<Location> blockArgLocs(sourceType.getRank(), loc);
  rewriter.createBlock(&padOp.getRegion(), padOp.getRegion().end(), blockArgTypes, blockArgLocs);

  auto zero = getOrCreateConstant(
      rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
  tensor::YieldOp::create(rewriter, loc, zero);
  rewriter.setInsertionPointAfter(padOp);
  return padOp.getResult();
}

// Materializes a zero-padded copy of a host-constant rank-2 matrix with `resultType`'s
// (larger or equal) shape. Returns `value` unchanged when no padding is needed.
// Fails when `value` is not a static rank-2 host constant or does not fit in
// `resultType`.
static FailureOr<Value> materializePaddedConstantMatrix(
    Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
  auto sourceType = cast<RankedTensorType>(value.getType());
  if (sourceType == resultType)
    return value;

  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();
  auto denseType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!denseType || denseType.getRank() != 2 || !denseType.hasStaticShape())
    return failure();
  if (resultType.getRank() != 2 || !resultType.hasStaticShape())
    return failure();

  ArrayRef<int64_t> sourceShape = denseType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  // Guard the copy loop below against writing past the end of the result buffer.
  if (sourceShape[0] > resultShape[0] || sourceShape[1] > resultShape[1])
    return failure();

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  Attribute zero = rewriter.getZeroAttr(resultType.getElementType());
  SmallVector<Attribute> resultValues(resultType.getNumElements(), zero);
  for (int64_t row = 0; row < sourceShape[0]; ++row)
    for (int64_t col = 0; col < sourceShape[1]; ++col)
      resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col];

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}

// Broadcasts a host-constant tensor (numpy-style, trailing dimensions aligned) to
// `resultType`, then zero-fills every column at or beyond `unpaddedColumns` in the
// last dimension. Fails when `value` is not a static host constant or is not
// broadcastable to the unpadded result shape.
static FailureOr<Value> materializePaddedBroadcastedConstantTensor(Value value,
    RankedTensorType resultType,
    int64_t unpaddedColumns,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  auto denseAttr = getHostConstDenseElementsAttr(value);
  if (!denseAttr)
    return failure();
  auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
  if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() > resultType.getRank())
    return failure();

  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  ArrayRef<int64_t> resultShape = resultType.getShape();
  // A rank-0 result has no column dimension to pad, and the unpadded region must fit.
  if (resultShape.empty() || unpaddedColumns > resultShape.back())
    return failure();

  SmallVector<int64_t> unpaddedResultShape(resultShape.begin(), resultShape.end());
  unpaddedResultShape.back() = unpaddedColumns;

  // Validate broadcastability: each aligned source dim must be 1 or match the result.
  const int64_t resultRank = static_cast<int64_t>(resultShape.size());
  const int64_t rankOffset = static_cast<int64_t>(resultShape.size() - sourceShape.size());
  for (int64_t resultIndex = 0; resultIndex < resultRank; ++resultIndex) {
    const int64_t sourceIndex = resultIndex - rankOffset;
    if (sourceIndex < 0)
      continue;
    const int64_t sourceDim = sourceShape[sourceIndex];
    const int64_t resultDim = unpaddedResultShape[resultIndex];
    if (sourceDim != 1 && sourceDim != resultDim)
      return failure();
  }

  SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
  auto sourceStrides = computeRowMajorStrides(sourceShape);
  auto resultStrides = computeRowMajorStrides(resultShape);
  Attribute zero = rewriter.getZeroAttr(resultType.getElementType());

  SmallVector<Attribute> resultValues;
  resultValues.reserve(resultType.getNumElements());
  for (int64_t flatIndex = 0; flatIndex < resultType.getNumElements(); ++flatIndex) {
    // Decompose the flat result index into per-dimension indices.
    int64_t remaining = flatIndex;
    SmallVector<int64_t> resultIndices(resultShape.size(), 0);
    for (int64_t dim = 0; dim < resultRank; ++dim) {
      resultIndices[dim] = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
      remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
    }

    if (resultIndices.back() >= unpaddedColumns) {
      resultValues.push_back(zero);
      continue;
    }

    // Map the result index back to the source, collapsing broadcast (size-1) dims.
    int64_t sourceFlatIndex = 0;
    for (int64_t resultIndex = 0; resultIndex < resultRank; ++resultIndex) {
      const int64_t sourceIndex = resultIndex - rankOffset;
      if (sourceIndex < 0)
        continue;
      const int64_t sourceDim = sourceShape[sourceIndex];
      const int64_t mappedIndex = sourceDim == 1 ? 0 : resultIndices[resultIndex];
      sourceFlatIndex += mappedIndex * sourceStrides[sourceIndex];
    }
    resultValues.push_back(sourceValues[sourceFlatIndex]);
  }

  auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}

// Prepares the Gemm bias C for the tiled path. Compile-time computable biases are
// broadcast and zero-padded to `paddedOutType`; runtime biases are accepted only
// when their type already equals `outType` (they are padded later, inside the
// reduction compute).
static FailureOr<Value> prepareBias(Value c,
    RankedTensorType outType,
    RankedTensorType paddedOutType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  auto cType = cast<RankedTensorType>(c.getType());
  if (!cType.hasStaticShape())
    return failure();
  if (isCompileTimeComputable(c))
    return materializePaddedBroadcastedConstantTensor(c, paddedOutType, outType.getDimSize(1), rewriter, loc);
  if (cType != outType)
    return failure();
  return c;
}

// Extracts a 1 x crossbarSize tile of A starting at (row, kOffset).
static Value extractATile(
    Value a, Value row, Value kOffset, RankedTensorType aTileType, ConversionPatternRewriter& rewriter, Location loc) {
  SmallVector<OpFoldResult> offsets {row, kOffset};
  SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
  SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult();
}

// Wraps zero-padding of `input` to `paddedInputType` in a single-input spatial
// compute. Returns `input` unchanged when it already has the padded type.
static Value createPaddedInputCompute(
    Value input, RankedTensorType paddedInputType, ConversionPatternRewriter& rewriter, Location loc) {
  auto inputType = cast<RankedTensorType>(input.getType());
  if (inputType == paddedInputType)
    return input;

  auto computeOp =
      createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
        Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
        spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
      });
  return computeOp.getResult(0);
}

// Builds the tiled VMM batch: one lane per (hSlice, kSlice, row) triple. Each lane
// multiplies a 1 x crossbarSize tile of A by a crossbarSize x crossbarSize tile of
// the padded B and writes the partial product into row `lane` of the
// `partialPiecesType` output.
//
// Returns whatever createSpatComputeBatch returns (a FailureOr of the batch op);
// the return type is deduced so it matches that helper exactly.
static auto createVmmBatch(Value a,
    Value b,
    RankedTensorType aType,
    RankedTensorType paddedBType,
    RankedTensorType partialPiecesType,
    int64_t numOutRows,
    int64_t numKSlices,
    int64_t numOutHSlices,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  const int64_t laneCount = partialPiecesType.getDimSize(0);
  const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
  return createSpatComputeBatch(rewriter,
      loc,
      TypeRange {partialPiecesType},
      laneCount,
      ValueRange {b},
      ValueRange {a},
      [&](detail::SpatComputeBatchBodyArgs args) {
        // row = lane mod numOutRows, matching the lane layout used for the K/H offsets.
        Value row = onnx_mlir::affineModConst(
            rewriter, loc, args.lane, numOutRows, rewriter.getInsertionBlock()->getParentOp());
        Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc);
        Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);

        auto aTileType = RankedTensorType::get({1, xbarSize}, aType.getElementType());
        auto bTileType = RankedTensorType::get({xbarSize, xbarSize}, paddedBType.getElementType());
        auto pieceType = RankedTensorType::get({1, xbarSize}, partialPiecesType.getElementType());

        Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc);

        SmallVector<OpFoldResult> bOffsets {kOffset, hOffset};
        SmallVector<OpFoldResult> bSizes {rewriter.getIndexAttr(xbarSize), rewriter.getIndexAttr(xbarSize)};
        auto unitStrides = getUnitStrides(rewriter, 2);
        Value bTile = tensor::ExtractSliceOp::create(
            rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides)
                          .getResult();

        Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();

        SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(xbarSize)};
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides);
      });
}

// For the dynamic-B path, lanes enumerate output elements row-major:
// row = lane floordiv numOutCols.
static Value createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) {
  if (numOutCols == 1)
    return lane;
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createOrFoldAffineApply(
      rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}

// Extracts column `column` of `matrix` (K x N) and reshapes it into `vectorType`
// (1 x K) via a K x 1 slice -> K collapse -> 1 x K expand.
static Value extractDynamicGemmBColumn(
    Value matrix, Value column, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  const int64_t reductionSize = vectorType.getDimSize(1);
  SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), column};
  SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(reductionSize), rewriter.getIndexAttr(1)};
  SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  auto columnSliceType = RankedTensorType::get({reductionSize, 1}, vectorType.getElementType());
  Value columnSlice =
      tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides).getResult();

  SmallVector<ReassociationIndices> collapseReassociation {
      ReassociationIndices {0, 1}
  };
  auto collapsedType = RankedTensorType::get({reductionSize}, vectorType.getElementType());
  Value collapsed =
      tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, columnSlice, collapseReassociation).getResult();

  SmallVector<ReassociationIndices> expandReassociation {
      ReassociationIndices {0, 1}
  };
  return tensor::ExpandShapeOp::create(rewriter, loc, vectorType, collapsed, expandReassociation).getResult();
}

// Extracts row `row` of `matrix` as a `vectorType` (1 x K) slice.
static Value extractDynamicGemmRowVector(
    Value matrix, Value row, RankedTensorType vectorType, ConversionPatternRewriter& rewriter, Location loc) {
  SmallVector<OpFoldResult> offsets {row, rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
  SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  return tensor::ExtractSliceOp::create(rewriter, loc, vectorType, matrix, offsets, sizes, strides).getResult();
}

// Checks that the static bias type `cType` (rank 0, 1 or 2) broadcasts to the 2-D
// `outType` with trailing dimensions aligned; returns `cType` on success.
static FailureOr<RankedTensorType> verifyDynamicGemmBiasType(RankedTensorType cType, RankedTensorType outType) {
  if (!cType.hasStaticShape() || cType.getRank() > 2)
    return failure();
  if (cType.getRank() == 0)
    return cType;

  const int64_t numOutRows = outType.getDimSize(0);
  const int64_t numOutCols = outType.getDimSize(1);
  if (cType.getRank() == 1) {
    const int64_t cols = cType.getDimSize(0);
    if (cols == 1 || cols == numOutCols)
      return cType;
    return failure();
  }

  const int64_t rows = cType.getDimSize(0);
  const int64_t cols = cType.getDimSize(1);
  if ((rows == 1 || rows == numOutRows) && (cols == 1 || cols == numOutCols))
    return cType;
  return failure();
}

// Returns true when the Gemm operand C is a real bias rather than the ONNX "none"
// placeholder. A null Value is treated as "no bias" instead of being dereferenced.
static bool hasGemmBias(Value c) {
  if (!c)
    return false;
  Operation* definingOp = c.getDefiningOp();
  return !definingOp || !isa<ONNXNoneOp>(definingOp);
}

// Returns a splat constant of `scalarType` filled with `value`.
static Value
createScalarTensorConstant(RankedTensorType scalarType, float value, ConversionPatternRewriter& rewriter, Location loc) {
  auto elementType = scalarType.getElementType();
  auto scalarAttr = rewriter.getFloatAttr(elementType, value);
  auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr);
  return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), denseAttr, scalarType);
}

// Extracts the bias element that broadcasts onto output position (row, column) as a
// `scalarType` (1 x 1) tensor. Size-1 bias dimensions are always read at index 0.
static Value createBroadcastedBiasScalar(Value bias,
    RankedTensorType biasType,
    Value row,
    Value column,
    RankedTensorType scalarType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  SmallVector<OpFoldResult> unitStrides(biasType.getRank(), rewriter.getIndexAttr(1));

  if (biasType.getRank() == 1) {
    SmallVector<OpFoldResult> offsets {
        biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
    SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1)};
    auto vectorType = RankedTensorType::get({1}, scalarType.getElementType());
    Value vector =
        tensor::ExtractSliceOp::create(rewriter, loc, vectorType, bias, offsets, sizes, unitStrides).getResult();
    SmallVector<ReassociationIndices> reassociation {
        ReassociationIndices {0, 1}
    };
    return tensor::ExpandShapeOp::create(rewriter, loc, scalarType, vector, reassociation).getResult();
  }

  if (biasType.getRank() == 2) {
    SmallVector<OpFoldResult> offsets {
        biasType.getDimSize(0) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(row),
        biasType.getDimSize(1) == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(column)};
    SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    return tensor::ExtractSliceOp::create(rewriter, loc, scalarType, bias, offsets, sizes, unitStrides).getResult();
  }

  // Rank-0 bias: read the single element and splat it.
  Value scalar = tensor::ExtractOp::create(rewriter, loc, bias, ValueRange {}).getResult();
  return tensor::SplatOp::create(rewriter, loc, scalarType, scalar).getResult();
}

// Builds the dynamic-B batch: one lane per output element (row-major). Each lane
// computes the dot product of row `row` of A and column `column` of B and stores
// it in row `lane` of the `scalarPiecesType` (laneCount x 1) output.
//
// Returns whatever createSpatComputeBatch returns (a FailureOr of the batch op).
static auto createVvdmulBatch(Value a,
    Value b,
    RankedTensorType aType,
    RankedTensorType bType,
    RankedTensorType scalarPiecesType,
    RankedTensorType outType,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  const int64_t numOutRows = outType.getDimSize(0);
  const int64_t numOutCols = outType.getDimSize(1);
  const int64_t reductionSize = aType.getDimSize(1);
  const int64_t laneCount = numOutRows * numOutCols;

  return createSpatComputeBatch(rewriter,
      loc,
      TypeRange {scalarPiecesType},
      laneCount,
      ValueRange {},
      ValueRange {a, b},
      [&](detail::SpatComputeBatchBodyArgs args) {
        Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc);
        Value column = onnx_mlir::affineModConst(
            rewriter, loc, args.lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());

        auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
        auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());

        Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc);
        Value bVector = extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc);
        Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();

        SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
        SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
        auto unitStrides = getUnitStrides(rewriter, 2);
        createParallelInsertSliceIntoBatchOutput(
            rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides);
      });
}

// Assembles the dynamic-B Gemm output: loops over all lanes of `scalarPieces`,
// scales each dot product by `alpha`, adds `beta * bias` (broadcast) when a bias is
// present, and inserts the result at (row, column) of an `outType` tensor.
//
// Returns whatever createSpatCompute returns for a fallible body (a FailureOr of the
// compute op).
static auto createDynamicGemmOutputCompute(Value scalarPieces,
    Value bias,
    RankedTensorType scalarPiecesType,
    RankedTensorType biasType,
    RankedTensorType outType,
    float alpha,
    float beta,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  const int64_t laneCount = scalarPiecesType.getDimSize(0);
  const int64_t numOutCols = outType.getDimSize(1);
  SmallVector<Value> inputs {scalarPieces};
  if (bias)
    inputs.push_back(bias);

  return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
    Value pieces = blockArgs[0];
    Value biasArg = bias ? blockArgs[1] : Value();
    auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
    Value outputInit =
        tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();

    Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
    Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
    Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);

    auto loop = buildNormalizedScfFor(rewriter,
        loc,
        c0,
        cLaneCount,
        c1,
        ValueRange {outputInit},
        [&](OpBuilder&, Location nestedLoc, Value lane, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
          Value outputAcc = iterArgs.front();
          Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, nestedLoc);
          Value column = onnx_mlir::affineModConst(
              rewriter, nestedLoc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp());

          SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
          SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
          SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
          Value scalar = tensor::ExtractSliceOp::create(
              rewriter, nestedLoc, scalarType, pieces, scalarOffsets, scalarSizes, unitStrides)
                             .getResult();

          if (alpha != 1.0f) {
            Value alphaTensor = createScalarTensorConstant(scalarType, alpha, rewriter, nestedLoc);
            scalar = spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, scalar, alphaTensor).getResult();
          }

          if (biasArg) {
            Value biasScalar =
                createBroadcastedBiasScalar(biasArg, biasType, row, column, scalarType, rewriter, nestedLoc);
            if (beta != 1.0f) {
              Value betaTensor = createScalarTensorConstant(scalarType, beta, rewriter, nestedLoc);
              biasScalar =
                  spatial::SpatVMulOp::create(rewriter, nestedLoc, scalarType, biasScalar, betaTensor).getResult();
            }
            scalar = spatial::SpatVAddOp::create(rewriter, nestedLoc, scalarType, scalar, biasScalar).getResult();
          }

          SmallVector<OpFoldResult> outputOffsets {row, column};
          Value outputNext = tensor::InsertSliceOp::create(
              rewriter, nestedLoc, scalar, outputAcc, outputOffsets, scalarSizes, unitStrides)
                                 .getResult();
          yielded.push_back(outputNext);
          return success();
        });
    if (failed(loop))
      return failure();

    spatial::SpatYieldOp::create(rewriter, loc, loop->results.front());
    return success();
  });
}

// First row, in the partial-pieces tensor, of the numOutRows-row group produced for
// (hSlice, kSlice): hSlice * (numKSlices * numOutRows) + kSlice * numOutRows.
static Value createPartialGroupOffset(Value hSlice,
    int64_t kSlice,
    int64_t numKSlices,
    int64_t numOutRows,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  MLIRContext* context = rewriter.getContext();
  AffineExpr d0 = getAffineDimExpr(0, context);
  return createOrFoldAffineApply(rewriter,
      loc,
      d0 * (numKSlices * numOutRows) + kSlice * numOutRows,
      ValueRange {hSlice},
      rewriter.getInsertionBlock()->getParentOp());
}

// Extracts the numOutRows x crossbarSize group of partial products for
// (hSlice, kSlice).
static Value extractReductionPiece(Value partialPiecesArg,
    Value hSlice,
    int64_t kSlice,
    RankedTensorType pieceType,
    int64_t numKSlices,
    int64_t numOutRows,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
  SmallVector<OpFoldResult> pieceSizes {
      rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
  SmallVector<OpFoldResult> pieceOffsets {
      createPartialGroupOffset(hSlice, kSlice, numKSlices, numOutRows, rewriter, loc), rewriter.getIndexAttr(0)};
  return tensor::ExtractSliceOp::create(
      rewriter, loc, pieceType, partialPiecesArg, pieceOffsets, pieceSizes, unitStrides)
      .getResult();
}

// Sums the numKSlices partial-product groups belonging to `hSlice` with a pairwise
// (tree-shaped) sequence of SpatVAddOps; an odd piece is carried to the next round.
static Value reducePartialPiecesForHSlice(Value partialPiecesArg,
    Value hSlice,
    RankedTensorType pieceType,
    int64_t numKSlices,
    int64_t numOutRows,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  SmallVector<Value> activePieces;
  activePieces.reserve(numKSlices);
  for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
    activePieces.push_back(
        extractReductionPiece(partialPiecesArg, hSlice, kSlice, pieceType, numKSlices, numOutRows, rewriter, loc));

  while (activePieces.size() > 1) {
    SmallVector<Value> nextPieces;
    nextPieces.reserve((activePieces.size() + 1) / 2);
    for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2)
      nextPieces.push_back(
          spatial::SpatVAddOp::create(
              rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1])
              .getResult());
    if (activePieces.size() % 2 != 0)
      nextPieces.push_back(activePieces.back());
    activePieces = std::move(nextPieces);
  }
  return activePieces.front();
}

// Builds the compute that reduces the tiled VMM partial products into the final
// Gemm output: for each output H slice it sums the K-slice partials, adds the
// (zero-padded) bias slice when present, inserts the result into a padded output,
// and finally slices the padded output back to `outType`.
//
// Returns whatever createSpatCompute returns for a fallible body (a FailureOr of the
// compute op).
static auto createReductionCompute(Value partialPieces,
    Value bias,
    RankedTensorType partialPiecesType,
    RankedTensorType outType,
    RankedTensorType paddedOutType,
    int64_t numKSlices,
    ConversionPatternRewriter& rewriter,
    Location loc) {
  SmallVector<Value> inputs {partialPieces};
  if (bias)
    inputs.push_back(bias);

  return createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) -> LogicalResult {
    Value partialPiecesArg = blockArgs[0];
    Value biasArg = bias ? blockArgs[1] : Value();
    // Runtime biases arrive unpadded (see prepareBias); pad them to the tiled width.
    if (biasArg && cast<RankedTensorType>(biasArg.getType()) != paddedOutType)
      biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc);

    const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
    const int64_t numOutRows = outType.getDimSize(0);
    const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue());
    auto pieceType = RankedTensorType::get({numOutRows, xbarSize}, partialPiecesType.getElementType());
    Value outputInit =
        tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();

    SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
    SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(xbarSize)};

    // Reduces one H slice, adds its bias slice, and inserts it into `outputAcc`.
    auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value {
      Value reduced = reducePartialPiecesForHSlice(
          partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc);
      Value hOffset = onnx_mlir::affineMulConst(
          rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
      if (biasArg) {
        SmallVector<OpFoldResult> biasOffsets {rewriter.getIndexAttr(0), hOffset};
        Value biasSlice =
            tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides)
                .getResult();
        reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult();
      }
      SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), hOffset};
      return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides)
          .getResult();
    };

    Value paddedOutput = outputInit;
    if (numOutHSlices == 1) {
      Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
      paddedOutput = buildOutputSlice(outputInit, hSlice);
    }
    else {
      Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
      Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
      Value cOutHSlices =
          getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
      auto hLoop = buildNormalizedScfFor(rewriter,
          loc,
          c0,
          cOutHSlices,
          c1,
          ValueRange {outputInit},
          [&](OpBuilder&, Location, Value hSlice, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
            yielded.push_back(buildOutputSlice(iterArgs.front(), hSlice));
            return success();
          });
      if (failed(hLoop))
        return failure();
      paddedOutput = hLoop->results.front();
    }

    // Drop the column padding introduced by tiling.
    Value result = paddedOutput;
    if (paddedOutType != outType) {
      SmallVector<OpFoldResult> outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
      SmallVector<OpFoldResult> outputSizes {
          rewriter.getIndexAttr(outType.getDimSize(0)), rewriter.getIndexAttr(outType.getDimSize(1))};
      result =
          tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides)
              .getResult();
    }

    spatial::SpatYieldOp::create(rewriter, loc, result);
    return success();
  });
}

// Lowers a static-shape, rank-2 ONNX Gemm to Spatial computes. A compile-time
// computable B takes the tiled crossbar (VMM) path; any other B takes the
// per-element dot-product (VVDMul) path.
struct GemmToSpatialComputes : OpConversionPattern<ONNXGemmOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(ONNXGemmOp gemmOp,
      ONNXGemmOpAdaptor gemmOpAdaptor,
      ConversionPatternRewriter& rewriter) const override;
};

} // namespace

LogicalResult GemmToSpatialComputes::matchAndRewrite(
    ONNXGemmOp gemmOp, ONNXGemmOpAdaptor gemmOpAdaptor, ConversionPatternRewriter& rewriter) const {
  Location loc = gemmOp.getLoc();
  Value a = gemmOpAdaptor.getA();
  Value b = gemmOpAdaptor.getB();
  Value c = gemmOpAdaptor.getC();

  auto aType = dyn_cast<RankedTensorType>(a.getType());
  auto bType = dyn_cast<RankedTensorType>(b.getType());
  auto outType = dyn_cast<RankedTensorType>(gemmOp.getY().getType());
  if (!aType || !bType || !outType)
    return failure();

  // --- Shape preconditions -------------------------------------------------
  if (!aType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A");
    return failure();
  }
  if (!bType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B");
    return failure();
  }
  if (!outType.hasStaticShape()) {
    pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result");
    return failure();
  }
  if (aType.getRank() != 2) {
    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input A", aType.getRank(), {2});
    return failure();
  }
  if (bType.getRank() != 2) {
    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input B", bType.getRank(), {2});
    return failure();
  }
  if (outType.getRank() != 2) {
    pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm result", outType.getRank(), {2});
    return failure();
  }

  // --- Materialize transA / transB as explicit transposes -------------------
  if (gemmOpAdaptor.getTransA()) {
    auto aShape = aType.getShape();
    auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType(), aType.getEncoding());
    a = ONNXTransposeOp::create(rewriter, loc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})).getResult();
    aType = transposedType;
  }
  if (gemmOpAdaptor.getTransB()) {
    auto bShape = bType.getShape();
    auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType(), bType.getEncoding());
    b = ONNXTransposeOp::create(rewriter, loc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})).getResult();
    bType = transposedType;
  }

  const int64_t numOutRows = outType.getDimSize(0);
  const int64_t numOutCols = outType.getDimSize(1);
  const int64_t reductionSize = aType.getDimSize(1);

  // --- Runtime B: one dot product per output element ------------------------
  if (!isCompileTimeComputable(b)) {
    bool hasC = hasGemmBias(c);
    float alpha = gemmOpAdaptor.getAlpha().convertToFloat();
    float beta = gemmOpAdaptor.getBeta().convertToFloat();

    RankedTensorType biasType;
    if (hasC) {
      auto cType = dyn_cast<RankedTensorType>(c.getType());
      if (!cType || !cType.hasStaticShape()) {
        pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
        return failure();
      }
      auto verifiedBiasType = verifyDynamicGemmBiasType(cType, outType);
      if (failed(verifiedBiasType)) {
        gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
        return failure();
      }
      biasType = *verifiedBiasType;
    }

    if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
        || bType.getDimSize(1) != numOutCols) {
      gemmOp.emitOpError("has inconsistent A, B, and output shapes");
      return failure();
    }

    const int64_t laneCount64 = numOutRows * numOutCols;
    if (laneCount64 > std::numeric_limits<int32_t>::max()) {
      gemmOp.emitOpError("requires Gemm dynamic batch lane count to fit in i32");
      return failure();
    }

    auto scalarPiecesType = RankedTensorType::get({laneCount64, 1}, outType.getElementType());
    auto batchOp = createVvdmulBatch(a, b, aType, bType, scalarPiecesType, outType, rewriter, loc);
    if (failed(batchOp))
      return failure();

    auto outputCompute = createDynamicGemmOutputCompute(batchOp->getResult(0),
        hasC ? c : Value(),
        scalarPiecesType,
        biasType,
        outType,
        alpha,
        beta,
        rewriter,
        loc);
    if (failed(outputCompute))
      return failure();

    rewriter.replaceOp(gemmOp, outputCompute->getResults());
    return success();
  }

  // --- Constant B: tiled crossbar VMM path -----------------------------------
  // alpha is folded into the constant weights.
  auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc);
  if (failed(scaledB)) {
    gemmOp.emitOpError("requires constant Gemm input B when alpha is not 1.0");
    return failure();
  }
  b = *scaledB;
  bType = cast<RankedTensorType>(b.getType());

  if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize
      || bType.getDimSize(1) != numOutCols) {
    gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling");
    return failure();
  }

  const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
  const int64_t numKSlices = ceilIntegerDivide(reductionSize, crossbarSize.getValue());
  const int64_t numOutHSlices = ceilIntegerDivide(numOutCols, crossbarSize.getValue());
  const int64_t paddedReductionSize = numKSlices * xbarSize;
  const int64_t paddedOutCols = numOutHSlices * xbarSize;

  auto paddedBType = RankedTensorType::get({paddedReductionSize, paddedOutCols}, bType.getElementType());
  auto paddedB = materializePaddedConstantMatrix(b, paddedBType, rewriter, loc);
  if (failed(paddedB)) {
    gemmOp.emitOpError("requires constant Gemm input B so tiled weights can be padded statically");
    return failure();
  }
  b = *paddedB;

  auto paddedAType = RankedTensorType::get({numOutRows, paddedReductionSize}, aType.getElementType());
  a = createPaddedInputCompute(a, paddedAType, rewriter, loc);
  aType = paddedAType;

  // beta is folded into the bias, which must therefore be constant when beta != 1.
  Value bias;
  bool hasC = hasGemmBias(c);
  auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType());
  if (hasC) {
    auto cType = dyn_cast<RankedTensorType>(c.getType());
    if (!cType || !cType.hasStaticShape()) {
      pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias");
      return failure();
    }
    auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc);
    if (failed(scaledC)) {
      gemmOp.emitOpError("requires constant Gemm bias C when beta is not 1.0");
      return failure();
    }
    c = *scaledC;

    auto preparedBias = prepareBias(c, outType, paddedOutType, rewriter, loc);
    if (failed(preparedBias)) {
      gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape");
      return failure();
    }
    bias = *preparedBias;
  }

  const int64_t laneCount64 = numOutHSlices * numKSlices * numOutRows;
  if (laneCount64 > std::numeric_limits<int32_t>::max()) {
    gemmOp.emitOpError("requires Gemm tiled batch lane count to fit in i32");
    return failure();
  }

  auto partialPiecesType = RankedTensorType::get({laneCount64, xbarSize}, outType.getElementType());
  auto batchOp = createVmmBatch(
      a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
  if (failed(batchOp))
    return failure();

  auto reductionCompute = createReductionCompute(
      batchOp->getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc);
  if (failed(reductionCompute))
    return failure();

  rewriter.replaceOp(gemmOp, reductionCompute->getResults());
  return success();
}

// Registers the Gemm -> Spatial lowering pattern.
void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
  patterns.insert<GemmToSpatialComputes>(ctx);
}

} // namespace onnx_mlir