diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index 91d0989..bcfe306 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -40,6 +40,21 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O return folder.getOrCreateConstant(hostBlock, arithDialect, value, type); } +Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) { + assert(anchorOp && "expected a valid anchor operation"); + Block* hostBlock = getHostConstantBlock(anchorOp); + for (Operation& op : *hostBlock) { + auto constantOp = dyn_cast(&op); + if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value) + continue; + return constantOp.getResult(); + } + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(hostBlock); + return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast(value)).getResult(); +} + Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) { return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder); } @@ -49,6 +64,11 @@ Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, Operation return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder); } +Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) { + Builder builder(anchorOp->getContext()); + return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter); +} + Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) { Builder builder(anchorOp->getContext()); return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder); diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp index 4754a7d..d5ea918 100644 --- a/src/PIM/Common/IR/ConstantUtils.hpp +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -1,10 +1,7 @@ #pragma once #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/FoldUtils.h" @@ -17,10 +14,17 @@ mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, mlir::Type type, mlir::OperationFolder& folder); +mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, + mlir::Attribute value, + mlir::Type type, + mlir::RewriterBase& rewriter); + mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder); mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); +mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter); + mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder); mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index c9f0893..96f5a46 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -12,13 +13,12 @@ #include "Common/Common.hpp" #include "Common/PimCommon.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" -#include "src/Compiler/CompilerOptions.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; @@ -118,6 +118,7 @@ void ONNXToSpatialPass::runOnOperation() { preTarget.addLegalDialect(); preTarget.addIllegalOp(); @@ -156,6 +157,7 @@ void ONNXToSpatialPass::runOnOperation() { target.addLegalDialect(); target.addIllegalOp(); @@ -189,6 +191,7 @@ void ONNXToSpatialPass::runOnOperation() { earlyPostTarget.addLegalDialect(); @@ -203,6 +206,7 @@ void ONNXToSpatialPass::runOnOperation() { postTarget.addLegalDialect(); postTarget.addDynamicallyLegalOp( diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp index aeeb47c..795cd37 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.cpp @@ -1,6 +1,4 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Support/LLVM.h" #include "Common/IR/WeightUtils.hpp" @@ -13,17 +11,28 @@ using namespace mlir; namespace onnx_mlir { -void checkWeightsDirectlyExtracted(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) { - for (auto extractSlice : func.getOps()) { - auto source = getCompileTimeSource(extractSlice.getOperation()); - if (source && hasWeightAlways(source->source) && source->chainLength > 1) { +namespace { - diagnostics.report(extractSlice.getOperation(), - [](Operation* illegalOp) { illegalOp->emitOpError("Weight not directly extracted"); }); +void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) { + func.walk([&](Operation* op) { + if (!hasWeightAlways(op)) + return; + + for (Value result : op->getResults()) { + if (hasOnlySpatialMvmVmmWeightUses(result)) + continue; + + diagnostics.report(op, [&](Operation* illegalOp) { + illegalOp->emitOpError( + "weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights"); + }); + return; } - } + }); } +} // namespace + LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { pim::CappedDiagnosticReporter diagnostics; @@ -38,9 +47,7 @@ LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) { "non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute"); }); } - - checkWeightsDirectlyExtracted(funcOp, diagnostics); - + checkWeightUseChains(funcOp, diagnostics); diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed"); return success(!diagnostics.hasFailure()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 5284daf..db6033a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -1,13 +1,19 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" +#include +#include + +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" @@ -25,11 +31,7 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr if (factor == 1.0f) return value; - auto constantOp = value.getDefiningOp(); - if (!constantOp) - return failure(); - - auto denseAttr = dyn_cast(constantOp.getValue()); + auto denseAttr = dyn_cast_or_null(getHostConstDenseElementsAttr(value)); if (!denseAttr) return failure(); @@ -65,254 +67,447 @@ static Value transposeForSpatial(Value value, return computeOp.getResult(0); } -static Value -expandRankOneBias(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) { - if (isCompileTimeComputable(value)) - return tensor::ExpandShapeOp::create(rewriter, - loc, - resultType, - value, - SmallVector { - {0, 1} - }); +static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + return getOrCreateHostIndexConstant(anchorOp, value, rewriter); +} - auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) { - Value expanded = tensor::ExpandShapeOp::create(rewriter, - loc, - resultType, - input, - SmallVector { - {0, 1} - }); - spatial::SpatYieldOp::create(rewriter, loc, expanded); +static Value +createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { + AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); +} + +static Value +multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) { + if (multiplier == 0) + return createIndexConstant(rewriter, 0); + if (multiplier == 1) + return value; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}); +} + +static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) { + if (divisor == 1) + return createIndexConstant(rewriter, 0); + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}); +} + +static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) { + return modIndexByConstant(lane, numOutRows, rewriter, loc); +} + +static Value createGemmBatchKOffset( + Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { + if (numKSlices == 1) + return createIndexConstant(rewriter, 0); + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApply( + rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); +} + +static Value createGemmBatchHOffset( + Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApply( + rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); +} + +static Value +createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) { + auto sourceType = cast(value.getType()); + SmallVector lowPads(sourceType.getRank(), rewriter.getIndexAttr(0)); + SmallVector highPads; + highPads.reserve(sourceType.getRank()); + for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape())) + highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim)); + + auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads); + auto* padBlock = new Block(); + for (int64_t i = 0; i < sourceType.getRank(); ++i) + padBlock->addArgument(rewriter.getIndexType(), loc); + padOp.getRegion().push_back(padBlock); + rewriter.setInsertionPointToStart(padBlock); + auto zero = arith::ConstantOp::create( + rewriter, loc, sourceType.getElementType(), rewriter.getZeroAttr(sourceType.getElementType())); + tensor::YieldOp::create(rewriter, loc, zero.getResult()); + rewriter.setInsertionPointAfter(padOp); + return padOp.getResult(); +} + +static FailureOr materializePaddedConstantMatrix(Value value, + RankedTensorType resultType, + ConversionPatternRewriter& rewriter, + Location loc) { + auto sourceType = cast(value.getType()); + if (sourceType == resultType) + return value; + + auto denseAttr = getHostConstDenseElementsAttr(value); + if (!denseAttr) + return failure(); + + auto denseType = dyn_cast(denseAttr.getType()); + if (!denseType || denseType.getRank() != 2 || !denseType.hasStaticShape()) + return failure(); + + ArrayRef sourceShape = denseType.getShape(); + ArrayRef resultShape = resultType.getShape(); + SmallVector sourceValues(denseAttr.getValues()); + Attribute zero = rewriter.getZeroAttr(resultType.getElementType()); + SmallVector resultValues(resultType.getNumElements(), zero); + + for (int64_t row = 0; row < sourceShape[0]; ++row) + for (int64_t col = 0; col < sourceShape[1]; ++col) + resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col]; + + auto resultAttr = DenseElementsAttr::get(resultType, resultValues); + return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult(); +} + +static FailureOr materializePaddedBroadcastedConstantTensor(Value value, + RankedTensorType resultType, + int64_t unpaddedColumns, + ConversionPatternRewriter& rewriter, + Location loc) { + auto denseAttr = getHostConstDenseElementsAttr(value); + if (!denseAttr) + return failure(); + + auto sourceType = dyn_cast(denseAttr.getType()); + if (!sourceType || !sourceType.hasStaticShape() || sourceType.getRank() > resultType.getRank()) + return failure(); + + ArrayRef sourceShape = sourceType.getShape(); + ArrayRef resultShape = resultType.getShape(); + SmallVector unpaddedResultShape(resultShape.begin(), resultShape.end()); + unpaddedResultShape.back() = unpaddedColumns; + + const int64_t rankOffset = static_cast(resultShape.size() - sourceShape.size()); + for (int64_t resultIndex = 0; resultIndex < static_cast(resultShape.size()); ++resultIndex) { + const int64_t sourceIndex = resultIndex - rankOffset; + if (sourceIndex < 0) + continue; + const int64_t sourceDim = sourceShape[sourceIndex]; + const int64_t resultDim = unpaddedResultShape[resultIndex]; + if (sourceDim != 1 && sourceDim != resultDim) + return failure(); + } + + SmallVector sourceValues(denseAttr.getValues()); + SmallVector sourceStrides = computeRowMajorStrides(sourceShape); + SmallVector resultStrides = computeRowMajorStrides(resultShape); + Attribute zero = rewriter.getZeroAttr(resultType.getElementType()); + + SmallVector resultValues; + resultValues.reserve(resultType.getNumElements()); + for (int64_t flatIndex = 0; flatIndex < resultType.getNumElements(); ++flatIndex) { + int64_t remaining = flatIndex; + SmallVector resultIndices(resultShape.size(), 0); + for (int64_t dim = 0; dim < static_cast(resultShape.size()); ++dim) { + resultIndices[dim] = resultStrides.empty() ? 0 : remaining / resultStrides[dim]; + remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim]; + } + + if (resultIndices.back() >= unpaddedColumns) { + resultValues.push_back(zero); + continue; + } + + int64_t sourceFlatIndex = 0; + for (int64_t resultIndex = 0; resultIndex < static_cast(resultShape.size()); ++resultIndex) { + const int64_t sourceIndex = resultIndex - rankOffset; + if (sourceIndex < 0) + continue; + + const int64_t sourceDim = sourceShape[sourceIndex]; + const int64_t mappedIndex = sourceDim == 1 ? 0 : resultIndices[resultIndex]; + sourceFlatIndex += mappedIndex * sourceStrides[sourceIndex]; + } + resultValues.push_back(sourceValues[sourceFlatIndex]); + } + + auto resultAttr = DenseElementsAttr::get(resultType, resultValues); + return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult(); +} + +static FailureOr prepareBias(Value c, + RankedTensorType outType, + RankedTensorType paddedOutType, + ConversionPatternRewriter& rewriter, + Location loc) { + auto cType = cast(c.getType()); + if (!cType.hasStaticShape()) + return failure(); + + if (isCompileTimeComputable(c)) + return materializePaddedBroadcastedConstantTensor(c, paddedOutType, outType.getDimSize(1), rewriter, loc); + + if (cType != outType) + return failure(); + + return c; +} + +static Value extractATile( + Value a, Value row, Value kOffset, RankedTensorType aTileType, ConversionPatternRewriter& rewriter, Location loc) { + SmallVector offsets {row, kOffset}; + SmallVector sizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())}; + SmallVector strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + + return tensor::ExtractSliceOp::create(rewriter, loc, aTileType, a, offsets, sizes, strides).getResult(); +} + +static Value createPaddedInputCompute(Value input, + RankedTensorType paddedInputType, + ConversionPatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + if (inputType == paddedInputType) + return input; + + auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) { + Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc); + spatial::SpatYieldOp::create(rewriter, loc, paddedInput); }); + return computeOp.getResult(0); } -struct GemmToManyGemv : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +static spatial::SpatComputeBatch createVmmBatch(Value a, + Value b, + RankedTensorType aType, + RankedTensorType paddedBType, + RankedTensorType partialPiecesType, + int64_t numOutRows, + int64_t numKSlices, + ConversionPatternRewriter& rewriter, + Location loc) { + const int64_t laneCount = partialPiecesType.getDimSize(0); + auto batchOp = spatial::SpatComputeBatch::create(rewriter, + loc, + TypeRange {partialPiecesType}, + rewriter.getI32IntegerAttr(static_cast(laneCount)), + ValueRange {b}, + ValueRange {a}); - LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, - ONNXGemmOpAdaptor gemmOpAdaptor, - ConversionPatternRewriter& rewriter) const override; -}; + SmallVector blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType}; + SmallVector blockArgLocs(blockArgTypes.size(), loc); + Block* body = + rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + rewriter.setInsertionPointToEnd(body); -struct GemvToSpatialCompute : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + auto lane = batchOp.getLaneArgument(); + auto weight = batchOp.getWeightArgument(0); + auto input = batchOp.getInputArgument(0); + auto output = batchOp.getOutputArgument(0); + assert(lane && weight && input && output && "malformed Gemm compute_batch body"); - LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, - ONNXGemmOpAdaptor gemmOpAdaptor, - ConversionPatternRewriter& rewriter) const override; -}; + Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc); + Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc); + Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, rewriter, loc); -struct GemmToSpatialComputeBatch : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + auto aTileType = RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType()); + auto bTileType = RankedTensorType::get( + {static_cast(crossbarSize.getValue()), static_cast(crossbarSize.getValue())}, + paddedBType.getElementType()); + auto pieceType = + RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType()); + Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc); - LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, - ONNXGemmOpAdaptor gemmOpAdaptor, - ConversionPatternRewriter& rewriter) const override; -}; + SmallVector bOffsets {kOffset, hOffset}; + SmallVector bSizes {rewriter.getIndexAttr(crossbarSize.getValue()), + rewriter.getIndexAttr(crossbarSize.getValue())}; + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + Value bTile = + tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult(); + Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); -static SmallVector materializeBatchRowSlices(Value matrix, - RankedTensorType matrixType, - ConversionPatternRewriter& rewriter, - Location loc) { - const int64_t numRows = matrixType.getDimSize(0); - auto rowType = RankedTensorType::get({1, matrixType.getDimSize(1)}, matrixType.getElementType()); - SmallVector resultTypes(static_cast(numRows), rowType); + auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); + rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + SmallVector pieceOffsets {*lane, rewriter.getIndexAttr(0)}; + SmallVector pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())}; + tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides); - if (isCompileTimeComputable(matrix)) { - auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrix); - return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); + rewriter.setInsertionPointAfter(batchOp); + return batchOp; +} + +static Value createPartialGroupOffset(Value hSlice, + int64_t kSlice, + int64_t numKSlices, + int64_t numOutRows, + ConversionPatternRewriter& rewriter, + Location loc) { + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); +} + +static Value extractReductionPiece(Value partialPiecesArg, + Value hSlice, + int64_t kSlice, + RankedTensorType pieceType, + int64_t numKSlices, + int64_t numOutRows, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector pieceSizes {rewriter.getIndexAttr(numOutRows), + rewriter.getIndexAttr(crossbarSize.getValue())}; + SmallVector pieceOffsets { + createPartialGroupOffset(hSlice, kSlice, numKSlices, numOutRows, rewriter, loc), rewriter.getIndexAttr(0)}; + return tensor::ExtractSliceOp::create( + rewriter, loc, pieceType, partialPiecesArg, pieceOffsets, pieceSizes, unitStrides) + .getResult(); +} + +static Value reducePartialPiecesForHSlice(Value partialPiecesArg, + Value hSlice, + RankedTensorType pieceType, + int64_t numKSlices, + int64_t numOutRows, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector activePieces; + activePieces.reserve(numKSlices); + for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice) + activePieces.push_back( + extractReductionPiece(partialPiecesArg, hSlice, kSlice, pieceType, numKSlices, numOutRows, rewriter, loc)); + + while (activePieces.size() > 1) { + SmallVector nextPieces; + nextPieces.reserve((activePieces.size() + 1) / 2); + for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2) + nextPieces.push_back( + spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1]) + .getResult()); + if (activePieces.size() % 2 != 0) + nextPieces.push_back(activePieces.back()); + activePieces = std::move(nextPieces); } - auto buildRowSlices = [&](Value matrixArg) { - auto extractRowsOp = spatial::SpatExtractRowsOp::create(rewriter, loc, TypeRange(resultTypes), matrixArg); - return SmallVector(extractRowsOp->result_begin(), extractRowsOp->result_end()); - }; + return activePieces.front(); +} - auto cloneBatchInputChainIntoSliceCompute = - [&](Value rootInput, SmallVector chainOps, Value rootValue) -> SmallVector { - auto sliceCompute = - createSpatCompute<1>(rewriter, loc, TypeRange(resultTypes), {}, ValueRange {rootInput}, [&](Value input) { - Value transformedMatrix = input; - if (!chainOps.empty()) { - IRMapping mapper; - mapper.map(rootValue, input); - for (Operation* chainOp : chainOps) - rewriter.clone(*chainOp, mapper); - transformedMatrix = cast(mapper.lookup(matrix)); - } - spatial::SpatYieldOp::create(rewriter, loc, buildRowSlices(transformedMatrix)); - }); - SmallVector rowSlices(sliceCompute->result_begin(), sliceCompute->result_end()); - return rowSlices; - }; +static spatial::SpatCompute createReductionCompute(Value partialPieces, + Value bias, + RankedTensorType partialPiecesType, + RankedTensorType outType, + RankedTensorType paddedOutType, + int64_t numKSlices, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector inputs {partialPieces}; + if (bias) + inputs.push_back(bias); - SmallVector chainOps; - Value rootValue = matrix; - while (Operation* definingOp = rootValue.getDefiningOp()) { - if (auto rootCompute = dyn_cast(definingOp)) { - SmallVector reversedChainOps(chainOps.rbegin(), chainOps.rend()); - return cloneBatchInputChainIntoSliceCompute( - rootCompute.getResult(cast(rootValue).getResultNumber()), reversedChainOps, rootValue); + auto computeOp = createSpatCompute(rewriter, loc, TypeRange {outType}, {}, inputs, [&](ValueRange blockArgs) { + Value partialPiecesArg = blockArgs[0]; + Value biasArg = bias ? blockArgs[1] : Value(); + if (biasArg && cast(biasArg.getType()) != paddedOutType) + biasArg = createZeroPaddedTensor(biasArg, paddedOutType, rewriter, loc); + + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(1), crossbarSize.getValue()); + auto pieceType = RankedTensorType::get({numOutRows, static_cast(crossbarSize.getValue())}, + partialPiecesType.getElementType()); + + Value outputInit = + tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult(); + SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector pieceSizes {rewriter.getIndexAttr(numOutRows), + rewriter.getIndexAttr(crossbarSize.getValue())}; + + auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { + Value reduced = + reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); + Value hOffset = multiplyIndexByConstant(hSlice, crossbarSize.getValue(), rewriter, loc); + if (biasArg) { + SmallVector biasOffsets {rewriter.getIndexAttr(0), hOffset}; + Value biasSlice = + tensor::ExtractSliceOp::create(rewriter, loc, pieceType, biasArg, biasOffsets, pieceSizes, unitStrides) + .getResult(); + reduced = spatial::SpatVAddOp::create(rewriter, loc, pieceType, reduced, biasSlice).getResult(); + } + + SmallVector outputOffsets {rewriter.getIndexAttr(0), hOffset}; + return tensor::InsertSliceOp::create(rewriter, loc, reduced, outputAcc, outputOffsets, pieceSizes, unitStrides) + .getResult(); + }; + + Value paddedOutput = outputInit; + if (numOutHSlices == 1) { + Value hSlice = createIndexConstant(rewriter, 0); + paddedOutput = buildOutputSlice(outputInit, hSlice); + } + else { + Value c0 = createIndexConstant(rewriter, 0); + Value c1 = createIndexConstant(rewriter, 1); + Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices); + auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); + rewriter.setInsertionPointToStart(hLoop.getBody()); + + Value hSlice = hLoop.getInductionVar(); + Value outputAcc = hLoop.getRegionIterArgs().front(); + scf::YieldOp::create(rewriter, loc, buildOutputSlice(outputAcc, hSlice)); + + rewriter.setInsertionPointAfter(hLoop); + paddedOutput = hLoop.getResult(0); } - if (definingOp->getNumOperands() != 1) - break; - if (!isa(definingOp)) - break; + Value result = paddedOutput; + if (paddedOutType != outType) { + SmallVector outputOffsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; + SmallVector outputSizes {rewriter.getIndexAttr(outType.getDimSize(0)), + rewriter.getIndexAttr(outType.getDimSize(1))}; + result = + tensor::ExtractSliceOp::create(rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, unitStrides) + .getResult(); + } + spatial::SpatYieldOp::create(rewriter, loc, result); + }); - chainOps.push_back(definingOp); - rootValue = definingOp->getOperand(0); - } - - SmallVector reversedChainOps(chainOps.rbegin(), chainOps.rend()); - return cloneBatchInputChainIntoSliceCompute(rootValue, reversedChainOps, rootValue); + return computeOp; } +struct GemmToSpatialComputes : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const override; +}; + } // namespace -LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp, - ONNXGemmOpAdaptor gemmOpAdaptor, - ConversionPatternRewriter& rewriter) const { +LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, + ONNXGemmOpAdaptor gemmOpAdaptor, + ConversionPatternRewriter& rewriter) const { Location loc = gemmOp.getLoc(); Value a = gemmOpAdaptor.getA(); Value b = gemmOpAdaptor.getB(); Value c = gemmOpAdaptor.getC(); if (gemmOpAdaptor.getTransA()) { - gemmOp.emitOpError("requires transA=false before Gemm row decomposition"); + gemmOp.emitOpError("requires transA=false before tiled Spatial Gemm lowering"); return failure(); } - bool hasC = !isa(c.getDefiningOp()); - - auto aType = cast(a.getType()); - auto outType = cast(gemmOp.getY().getType()); - if (!aType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A"); + auto aType = dyn_cast(a.getType()); + auto bType = dyn_cast(b.getType()); + auto outType = dyn_cast(gemmOp.getY().getType()); + if (!aType || !bType || !outType) return failure(); - } - if (!outType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result"); - return failure(); - } - - const int64_t numOutRows = aType.getDimSize(0); - - // Only decompose when there are multiple rows to split - if (numOutRows <= 1) - return failure(); - - auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc); - if (failed(scaledB)) - return failure(); - b = *scaledB; - - RankedTensorType cType = nullptr; - bool cHasNumOutRows = false; - if (hasC) { - auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); - if (failed(scaledC)) - return failure(); - c = *scaledC; - cType = cast(c.getType()); - // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling - if (cType.getRank() == 1) { - auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); - c = expandRankOneBias(c, expandedType, rewriter, loc); - cType = expandedType; - } - if (!cType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias"); - return failure(); - } - if (cType.getRank() != 2) { - pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2}); - return failure(); - } - cHasNumOutRows = cType.getDimSize(0) == numOutRows; - } - - auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); - SmallVector aSlices = materializeBatchRowSlices(a, aType, rewriter, loc); - SmallVector cSlices; - if (hasC && cHasNumOutRows) - cSlices = materializeBatchRowSlices(c, cType, rewriter, loc); - - SmallVector gemvOps; - gemvOps.reserve(static_cast(numOutRows)); - for (int64_t rowIdx = 0; rowIdx < numOutRows; rowIdx++) { - Value cSlice = c; - if (hasC) { - if (cHasNumOutRows) - cSlice = cSlices[static_cast(rowIdx)]; - else if (!isVectorShape(getTensorShape(c))) { - gemmOp.emitOpError("requires Gemm bias C to be vector-like when shared across decomposed rows"); - return failure(); - } - } - - auto gemvOp = ONNXGemmOp::create(rewriter, - loc, - outRowType, - aSlices[static_cast(rowIdx)], - b, - cSlice, - rewriter.getF32FloatAttr(1.0f), - rewriter.getF32FloatAttr(1.0f), - gemmOp.getTransAAttr(), - gemmOp.getTransBAttr()); - gemvOps.push_back(gemvOp.getY()); - } - - auto concatComputeOp = createSpatCompute(rewriter, loc, gemmOp.getType(), {}, gemvOps, [&](ValueRange gemvOpsArgs) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, /*axis=*/0, gemvOpsArgs)); - }); - - rewriter.replaceOp(gemmOp, concatComputeOp); - return success(); -} - -LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, - ONNXGemmOpAdaptor gemmOpAdaptor, - ConversionPatternRewriter& rewriter) const { - Location gemmLoc = gemmOp.getLoc(); - Value a = gemmOpAdaptor.getA(); - Value b = gemmOpAdaptor.getB(); - Value c = gemmOpAdaptor.getC(); - Value out = gemmOp.getY(); - - float alpha = gemmOpAdaptor.getAlpha().convertToFloat(); - float beta = gemmOpAdaptor.getBeta().convertToFloat(); - bool transA = gemmOpAdaptor.getTransA(); - bool transB = gemmOpAdaptor.getTransB(); - - auto aType = cast(a.getType()); - auto bType = cast(b.getType()); - auto outType = cast(out.getType()); - - RankedTensorType cType = nullptr; - bool hasC = !isa(c.getDefiningOp()); - if (hasC) { - cType = cast(c.getType()); - // Expand rank-1 bias [N] to rank-2 [1, N] for uniform handling - if (cType.getRank() == 1) { - auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); - c = expandRankOneBias(c, expandedType, rewriter, gemmLoc); - cType = expandedType; - } - if (!cType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias"); - return failure(); - } - if (cType.getRank() != 2) { - pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2}); - return failure(); - } - } - if (!aType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A"); return failure(); @@ -325,189 +520,29 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result"); return failure(); } - - if (!isVectorShape(aType.getShape()) || (hasC && !isVectorShape(cType.getShape()))) - // Not a gemv + if (aType.getRank() != 2) { + pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input A", aType.getRank(), {2}); return failure(); - - if (transA) { - auto aShape = aType.getShape(); - auto transposedType = RankedTensorType::get({aShape[1], aShape[0]}, aType.getElementType()); - a = transposeForSpatial(a, transposedType, {1, 0}, rewriter, gemmLoc); - aType = cast(a.getType()); } - if (transB) { - auto bShape = bType.getShape(); - auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); - b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, gemmLoc); - bType = cast(b.getType()); + if (bType.getRank() != 2) { + pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm input B", bType.getRank(), {2}); + return failure(); } - - if (alpha != 1.0f) { - auto scaledB = materializeScaledConstantTensor(b, alpha, rewriter, gemmLoc); - if (failed(scaledB)) - return failure(); - b = *scaledB; - bType = cast(b.getType()); - alpha = 1.0f; - } - if (hasC && beta != 1.0f) { - auto scaledC = materializeScaledConstantTensor(c, beta, rewriter, gemmLoc); - if (failed(scaledC)) - return failure(); - c = *scaledC; - cType = cast(c.getType()); - beta = 1.0f; - } - - auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); - auto [bNumHSlices, bLastHSliceSize] = ceilIntegerDivideWithRemainder(bType.getDimSize(1), crossbarSize.getValue()); - auto bNumVSlices = aNumHSlices; - auto cNumHSlices = bNumHSlices; - auto cLastHSliceSize = bLastHSliceSize; - auto outNumHSlices = cNumHSlices; - auto outLastHSliceSize = cLastHSliceSize; - - const size_t coresPerVSlice = ceilIntegerDivide(bNumVSlices, crossbarCountInCore.getValue()); - - DenseMap> aHSlices = sliceVectorPerCrossbarPerCore(a, rewriter, gemmLoc); - - DenseMap>> bTiles = - tileMatrix(b, crossbarSize, crossbarSize, rewriter, gemmLoc); - - SmallVector cHSlices; - if (hasC && cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1) - c = broadcastToVector(c, bType.getDimSize(1), rewriter, gemmLoc); - if (hasC) - cHSlices = sliceVector(c, crossbarSize, rewriter, gemmLoc); - - RankedTensorType outHSliceType = - RankedTensorType::get({1, static_cast(crossbarSize)}, outType.getElementType()); - RankedTensorType outLastHSliceType = - RankedTensorType::get({1, static_cast(bLastHSliceSize)}, outType.getElementType()); - - SmallVector outHSlices; - outHSlices.reserve(outNumHSlices); - for (size_t outSliceId = 0; outSliceId < outNumHSlices; outSliceId++) { - RankedTensorType currOutHSliceType = outHSliceType; - if (outSliceId == outNumHSlices - 1 && outLastHSliceSize != 0) - currOutHSliceType = outLastHSliceType; - - SmallVector partialResults; - partialResults.reserve(coresPerVSlice); - for (size_t coreId = 0; coreId < coresPerVSlice; coreId++) { - SmallVector weights; - weights.reserve(aHSlices[coreId].size()); - - for (size_t aSliceId = 0; aSliceId < aHSlices[coreId].size(); aSliceId++) - weights.push_back(bTiles[outSliceId][coreId][aSliceId]); - - auto computeOp = - spatial::SpatCompute::create(rewriter, gemmLoc, TypeRange {currOutHSliceType}, weights, aHSlices[coreId]); - SmallVector blockArgTypes; - SmallVector blockArgLocs; - blockArgTypes.reserve(weights.size() + aHSlices[coreId].size()); - blockArgLocs.reserve(weights.size() + aHSlices[coreId].size()); - for (Value weight : weights) { - blockArgTypes.push_back(weight.getType()); - blockArgLocs.push_back(gemmLoc); - } - for (Value input : aHSlices[coreId]) { - blockArgTypes.push_back(input.getType()); - blockArgLocs.push_back(gemmLoc); - } - Block* body = - rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - rewriter.setInsertionPointToEnd(body); - - SmallVector vmmOutputs; - vmmOutputs.reserve(aHSlices[coreId].size()); - for (auto aHSliceId : llvm::seq(0, aHSlices[coreId].size())) { - auto weightArg = computeOp.getWeightArgument(aHSliceId); - auto inputArg = computeOp.getInputArgument(aHSliceId); - if (!weightArg || !inputArg) - return failure(); - vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg)); - } - if (vmmOutputs.empty()) { - gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); - return failure(); - } - - Value partialVmmSum = sumTensors(vmmOutputs, rewriter); - spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum); - rewriter.setInsertionPointAfter(computeOp); - - partialResults.push_back(computeOp->getResult(0)); - } - - if (hasC) { - Value cHSlice = cHSlices[outSliceId]; - partialResults.push_back(cHSlice); - } - - auto reduceComputeOp = - createSpatCompute(rewriter, gemmLoc, currOutHSliceType, {}, partialResults, [&](ValueRange blockArgs) { - SmallVector values(blockArgs.begin(), blockArgs.end()); - Value outHSlice = sumTensors(values, rewriter); - spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice); - }); - - outHSlices.push_back(reduceComputeOp.getResult(0)); - } - - auto concatComputeOp = - createSpatCompute(rewriter, gemmLoc, gemmOp.getType(), {}, outHSlices, [&](ValueRange blockArgs) { - spatial::SpatYieldOp::create(rewriter, gemmLoc, createSpatConcat(rewriter, gemmLoc, /*axis=*/1, blockArgs)); - }); - - rewriter.replaceOp(gemmOp, concatComputeOp); - return success(); -} - -LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, - ONNXGemmOpAdaptor gemmOpAdaptor, - ConversionPatternRewriter& rewriter) const { - Location loc = gemmOp.getLoc(); - Value a = gemmOpAdaptor.getA(); - Value b = gemmOpAdaptor.getB(); - Value c = gemmOpAdaptor.getC(); - - if (gemmOpAdaptor.getTransA()) { - gemmOp.emitOpError("requires transA=false before batch Gemm lowering"); + if (outType.getRank() != 2) { + pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm result", outType.getRank(), {2}); return failure(); } - bool hasC = !isa(c.getDefiningOp()); - - auto aType = cast(a.getType()); - auto bType = cast(b.getType()); - auto outType = cast(gemmOp.getY().getType()); - if (!aType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input A"); + if (!isCompileTimeComputable(b)) { + gemmOp.emitOpError("requires Gemm input B to be statically computed from constants"); return failure(); } - if (!bType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm input B"); - return failure(); - } - if (!outType.hasStaticShape()) { - pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm result"); - return failure(); - } - - const int64_t numOutRows = aType.getDimSize(0); - if (numOutRows <= 1) - return failure(); - - // Only handle the single-tile case: K <= crossbarSize and N <= crossbarSize - if (aType.getDimSize(1) > static_cast(crossbarSize.getValue()) - || outType.getDimSize(1) > static_cast(crossbarSize.getValue())) - return failure(); auto scaledB = materializeScaledConstantTensor(b, gemmOpAdaptor.getAlpha().convertToFloat(), rewriter, loc); - if (failed(scaledB)) + if (failed(scaledB)) { + gemmOp.emitOpError("requires constant Gemm input B when alpha is not 1.0"); return failure(); + } b = *scaledB; bType = cast(b.getType()); @@ -517,86 +552,74 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp, b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc); bType = cast(b.getType()); } - (void) bType; - Value sharedBias; + const int64_t numOutRows = outType.getDimSize(0); + const int64_t numOutCols = outType.getDimSize(1); + const int64_t reductionSize = aType.getDimSize(1); + if (aType.getDimSize(0) != numOutRows || bType.getDimSize(0) != reductionSize || bType.getDimSize(1) != numOutCols) { + gemmOp.emitOpError("has inconsistent A, B, and output shapes after transpose handling"); + return failure(); + } + + const int64_t numKSlices = ceilIntegerDivide(reductionSize, crossbarSize.getValue()); + const int64_t numOutHSlices = ceilIntegerDivide(numOutCols, crossbarSize.getValue()); + const int64_t paddedReductionSize = numKSlices * static_cast(crossbarSize.getValue()); + const int64_t paddedOutCols = numOutHSlices * static_cast(crossbarSize.getValue()); + + auto paddedBType = RankedTensorType::get({paddedReductionSize, paddedOutCols}, bType.getElementType()); + auto paddedB = materializePaddedConstantMatrix(b, paddedBType, rewriter, loc); + if (failed(paddedB)) { + gemmOp.emitOpError("requires constant Gemm input B so tiled weights can be padded statically"); + return failure(); + } + b = *paddedB; + auto paddedAType = RankedTensorType::get({numOutRows, paddedReductionSize}, aType.getElementType()); + a = createPaddedInputCompute(a, paddedAType, rewriter, loc); + aType = paddedAType; + + Value bias; + bool hasC = !isa(c.getDefiningOp()); + auto paddedOutType = RankedTensorType::get({numOutRows, paddedOutCols}, outType.getElementType()); if (hasC) { - auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); - if (failed(scaledC)) - return failure(); - c = *scaledC; - auto cType = cast(c.getType()); - if (cType.getRank() == 1) { - auto expandedType = RankedTensorType::get({1, cType.getDimSize(0)}, cType.getElementType()); - c = expandRankOneBias(c, expandedType, rewriter, loc); - cType = cast(c.getType()); - } - if (!cType.hasStaticShape()) { + auto cType = dyn_cast(c.getType()); + if (!cType || !cType.hasStaticShape()) { pim::emitUnsupportedStaticShapeDiagnostic(gemmOp, "Gemm bias"); return failure(); } - if (cType.getRank() != 2) { - pim::emitUnsupportedRankDiagnostic(gemmOp, "Gemm bias", cType.getRank(), {1, 2}); + + auto scaledC = materializeScaledConstantTensor(c, gemmOpAdaptor.getBeta().convertToFloat(), rewriter, loc); + if (failed(scaledC)) { + gemmOp.emitOpError("requires constant Gemm bias C when beta is not 1.0"); return failure(); } - // Row-specific bias can't share a single template body; fall through to GemmToManyGemv - if (cType.getDimSize(0) == numOutRows && numOutRows > 1) + c = *scaledC; + + auto preparedBias = prepareBias(c, outType, paddedOutType, rewriter, loc); + if (failed(preparedBias)) { + gemmOp.emitOpError("requires Gemm bias C to be broadcastable to the output shape"); return failure(); - if (cType.getDimSize(0) == 1 && cType.getDimSize(1) == 1) - c = broadcastToVector(c, outType.getDimSize(1), rewriter, loc); - sharedBias = c; + } + bias = *preparedBias; } - auto outRowType = RankedTensorType::get({1, outType.getDimSize(1)}, outType.getElementType()); - auto aRowType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); - auto batchOp = spatial::SpatComputeBatch::create(rewriter, - loc, - TypeRange {outType}, - rewriter.getI32IntegerAttr(static_cast(numOutRows)), - ValueRange {b}, - ValueRange {a}); - - SmallVector blockArgTypes {rewriter.getIndexType(), bType, aType, outType}; - SmallVector blockArgLocs(4, loc); - Block* body = - rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - rewriter.setInsertionPointToEnd(body); - - auto lane = batchOp.getLaneArgument(); - auto weight = batchOp.getWeightArgument(0); - auto packedInput = batchOp.getInputArgument(0); - auto packedOutput = batchOp.getOutputArgument(0); - if (!lane || !weight || !packedInput || !packedOutput) + const int64_t laneCount64 = numOutHSlices * numKSlices * numOutRows; + if (laneCount64 > std::numeric_limits::max()) { + gemmOp.emitOpError("requires Gemm tiled batch lane count to fit in i32"); return failure(); + } - SmallVector inputOffsets {*lane, rewriter.getIndexAttr(0)}; - SmallVector inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; - SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value row = - tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides) - .getResult(); + auto partialPiecesType = + RankedTensorType::get({laneCount64, static_cast(crossbarSize.getValue())}, outType.getElementType()); + auto batchOp = createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, rewriter, loc); + auto reductionCompute = createReductionCompute( + batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc); - Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult(); - Value laneResult = vmmResult; - if (sharedBias) - laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult(); - - auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); - rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - SmallVector outputOffsets {*lane, rewriter.getIndexAttr(0)}; - SmallVector outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))}; - tensor::ParallelInsertSliceOp::create( - rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides); - rewriter.setInsertionPointAfter(batchOp); - - rewriter.replaceOp(gemmOp, batchOp.getResults()); + rewriter.replaceOp(gemmOp, reductionCompute.getResults()); return success(); } void populateGemmPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { - patterns.insert(ctx, PatternBenefit(2)); - patterns.insert(ctx); - patterns.insert(ctx); + patterns.insert(ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp index 6454a3b..a1bd540 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp +++ b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescingPass.cpp @@ -67,7 +67,7 @@ static CoalescingReportRow getTotalRow(const CoalescingReportEntry& entry) { } static void emitReport(ArrayRef entries) { - std::fstream file = openReportFile("static_memory_coalescing_report"); + std::fstream file = openReportFile("memory_coalescing_report"); if (!file.is_open()) return; diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 3489ddf..bd4024c 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -49,6 +49,7 @@ def SpatCompute : SpatOp<"compute", insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc); std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); + ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); }]; @@ -84,6 +85,7 @@ def SpatComputeBatch : SpatOp<"compute_batch", insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc); std::optional> insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc); + ::llvm::SetVector<::mlir::Value, ::llvm::SmallVector<::mlir::Value, 4>, ::llvm::SmallDenseSet<::mlir::Value, 4>> getCrossbarWeights(); ::mlir::FailureOr> insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc); }]; diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 20a132b..c315219 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -1,5 +1,6 @@ #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" #include @@ -35,6 +36,17 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i cast(op).getProperties().setOperandSegmentSizes({weightCount, inputCount}); } +using CrossbarWeightSet = llvm::SetVector, llvm::SmallDenseSet>; + +CrossbarWeightSet collectCrossbarWeights(Region& body) { + CrossbarWeightSet weights; + body.walk([&](SpatVMMOp vmmOp) { + Value weight = vmmOp.getWeight(); + weights.insert(weight); + }); + return weights; +} + } // namespace std::optional SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); } @@ -45,7 +57,6 @@ std::optional SpatCompute::getInputArgument(unsigned idx) { std::optional> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) { if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { - llvm::dbgs() << "Disse netanyao\n"; auto index = std::distance(getWeights().begin(), existing); return { {*existing, *getWeightArgument(index)} @@ -75,6 +86,8 @@ std::optional> SpatCompute::insertInput(unsigne return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); } +CrossbarWeightSet SpatCompute::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } + FailureOr> SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { if (idx > getNumResults()) @@ -127,7 +140,6 @@ std::optional> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) { auto index = std::distance(getWeights().begin(), existing); - llvm::dbgs() << "Bum bum bum bum\n"; return { {*existing, *getWeightArgument(index)} }; @@ -156,6 +168,8 @@ std::optional> SpatComputeBatch::insertInput(un return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg); } +CrossbarWeightSet SpatComputeBatch::getCrossbarWeights() { return collectCrossbarWeights(getBody()); } + FailureOr> SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) { if (idx > getNumResults()) diff --git a/src/PIM/Dialect/Spatial/SpatialOps.hpp b/src/PIM/Dialect/Spatial/SpatialOps.hpp index 89069b5..7dc89fd 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.hpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.hpp @@ -10,6 +10,9 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/ParallelCombiningOpInterface.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SetVector.h" + #include #include #include diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 8751dfa..2252586 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -10,6 +10,7 @@ #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp" using namespace mlir; @@ -239,6 +240,7 @@ void SpatCompute::print(OpAsmPrinter& printer) { printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printer << " "; printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); + printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size(); if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) printer << " coreId " << coreIdAttr.getInt(); @@ -264,6 +266,7 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { SmallVector weightTypes; SmallVector inputTypes; SmallVector outputTypes; + int32_t crossbarWeightCount = 0; int32_t coreId = 0; if (parseBoundValueList(parser, ListDelimiter::Square, weightArgs, weights)) @@ -273,9 +276,14 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { if (parseBoundValueList(parser, ListDelimiter::Paren, inputArgs, inputs)) return failure(); + bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); + if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) + return failure(); + bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); if (hasCoreId && parser.parseInteger(coreId)) return failure(); + (void) crossbarWeightCount; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parseCompressedRepeatedList( @@ -357,6 +365,7 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square); printer << " "; printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren); + printer << " crossbarWeights " << collectDistinctCrossbarWeights(getOperation()).size(); if (getNumResults() != 0) { printer << " shared_outs"; @@ -395,6 +404,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) SmallVector weightTypes; SmallVector inputTypes; SmallVector outputTypes; + int32_t crossbarWeightCount = 0; SmallVector coreIds; if (parser.parseArgument(laneArg) || parser.parseEqual() || parser.parseInteger(lowerBound) @@ -413,9 +423,14 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) if (parseBlockArgumentList(parser, outputArgs)) return failure(); + bool hasCrossbarWeightCount = parseOptionalKeywordAlias(parser, "crossbarWeights", "crossbar_weights"); + if (hasCrossbarWeightCount && parser.parseInteger(crossbarWeightCount)) + return failure(); + bool hasCoreIds = parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) return failure(); + (void) crossbarWeightCount; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parseCompressedOrTupleTypeList(parser, ListDelimiter::Square, weightTypes) diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 364832d..df421fc 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -50,10 +50,9 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) { template static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) { - for (Value weight : computeOp.getWeights()) { + for (Value weight : computeOp.getWeights()) if (!isCompileTimeComputable(weight)) return computeOp.emitOpError() << kind << " weights must be statically computed from constants"; - } return success(); } @@ -131,11 +130,9 @@ verifyStaticUnitStrideExtractSliceOp(tensor::ExtractSliceOp sliceOp, BlockArgume return sliceOp.emitOpError() << kind << " requires static slice sizes"; auto offsets = sliceOp.getOffsets(); - for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) { - bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset); - if (!supported) + for (Value offset : offsets) + if (!isSupportedLaneOffsetExpr(offset, laneArg)) return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets"; - } return success(); } @@ -155,11 +152,9 @@ static LogicalResult verifyStaticUnitStrideParallelInsertSliceOp(tensor::Paralle return sliceOp.emitOpError() << kind << " requires static slice sizes"; auto offsets = sliceOp.getOffsets(); - for (auto [offsetIndex, offset] : llvm::enumerate(offsets)) { - bool supported = offsetIndex == 0 ? isSupportedLaneOffsetExpr(offset, laneArg) : isConstantIndexLike(offset); - if (!supported) + for (Value offset : offsets) + if (!isSupportedLaneOffsetExpr(offset, laneArg)) return sliceOp.emitOpError() << kind << " requires simple lane-dependent offsets"; - } return success(); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index ec97840..cc5ff13 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1,8 +1,6 @@ #include "mlir/Analysis/TopologicalSortUtils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" @@ -34,6 +32,7 @@ #include #include "MaterializeMergeSchedule.hpp" +#include "Scheduling/ComputeGraph.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" #include "Scheduling/MergeSchedulingAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/CompactAsmUtils.hpp" @@ -282,11 +281,8 @@ void emitMotifProfile(func::FuncOp funcOp) { for (auto [index, compute] : llvm::enumerate(computes)) { ComputeMotifInfo& info = computeInfos[index]; - for (Operation& op : compute.getBody().front()) { - info.instructionCount++; - if (isa(&op)) - info.weightedVmmCount++; - } + info.instructionCount = spatial::countComputeBodyInstructions(compute.getBody()); + compute.getBody().walk([&](spatial::SpatVMMOp) { info.weightedVmmCount++; }); if (info.weightedVmmCount > 0) { weightedVmmNodeCount++; weightedVmmOpCount += info.weightedVmmCount; @@ -480,7 +476,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu struct ReportRow { uint64_t id = 0; uint64_t logicalComputeCount = 0; - uint64_t weightCount = 0; + uint64_t crossbarCount = 0; uint64_t instructionCount = 0; bool isRebatched = false; SmallVector coreIds; @@ -490,38 +486,40 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu uint64_t totalLogicalComputes = 0; uint64_t totalBatchComputeOps = 0; uint64_t totalInstructionCount = 0; - uint64_t totalWeightCount = 0; + uint64_t totalCrossbarCount = 0; uint64_t nextBatchId = 0; std::vector collectedData; + auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t { + return static_cast(spatial::collectDistinctCrossbarWeights(op).size()); + }; + for (Operation& op : funcOp.getBody().front()) { if (auto spatCompute = dyn_cast(&op)) { - uint64_t numInst = 0; - for (auto& _ : spatCompute.getRegion().front()) - ++numInst; + uint64_t numInst = spatial::countComputeBodyInstructions(spatCompute.getBody()); + uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(spatCompute.getOperation()); SmallVector coreIds; if (auto coreId = getComputeCoreId(spatCompute)) coreIds.push_back(*coreId); - collectedData.push_back({totalComputeOps++, 1, spatCompute.getWeights().size(), numInst, false, coreIds}); + collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds}); totalLogicalComputes += 1; totalInstructionCount += numInst; - totalWeightCount += spatCompute.getWeights().size(); + totalCrossbarCount += perInstanceCrossbarCount; continue; } if (auto batch = dyn_cast(&op)) { - uint64_t numInst = 0; - for (auto& _ : batch.getRegion().front()) - ++numInst; + uint64_t numInst = spatial::countComputeBodyInstructions(batch.getBody()); uint64_t logicalCount = static_cast(batch.getLaneCount()); + uint64_t perInstanceCrossbarCount = getPerInstanceCrossbarCount(batch.getOperation()); SmallVector coreIds; if (auto coreIdsAttr = batch->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) llvm::append_range(coreIds, coreIdsAttr.asArrayRef()); - collectedData.push_back({nextBatchId++, logicalCount, batch.getWeights().size(), numInst, true, coreIds}); + collectedData.push_back({nextBatchId++, logicalCount, perInstanceCrossbarCount * logicalCount, numInst, true, coreIds}); totalComputeOps += 1; totalLogicalComputes += logicalCount; totalBatchComputeOps += 1; totalInstructionCount += numInst * logicalCount; - totalWeightCount += batch.getWeights().size(); + totalCrossbarCount += perInstanceCrossbarCount * logicalCount; } } @@ -531,7 +529,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu {"Number of logical computes", std::to_string(totalLogicalComputes) }, {"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) }, {"Number of instructions", std::to_string(totalInstructionCount)}, - {"Number of used crossbars", std::to_string(totalWeightCount) } + {"Number of used crossbars", std::to_string(totalCrossbarCount) } }; printReportTotalsBlock(os, totalFields); if (!collectedData.empty()) @@ -545,7 +543,7 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu for (uint64_t nI = cI + 1; nI < totalComputeOps; ++nI) { ReportRow next = collectedData[nI]; - if (current.isRebatched == next.isRebatched && current.weightCount == next.weightCount + if (current.isRebatched == next.isRebatched && current.crossbarCount == next.crossbarCount && current.instructionCount == next.instructionCount && current.logicalComputeCount == next.logicalComputeCount) lastIndex = nI; @@ -578,20 +576,20 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu os << ":\n"; uint64_t perCoreLogicalComputeCount = current.isRebatched ? 1 : current.logicalComputeCount; uint64_t perCoreInstructionCount = current.instructionCount; - uint64_t perCoreWeightCount = - current.logicalComputeCount == 0 ? 0 : current.weightCount / current.logicalComputeCount; + uint64_t perCoreCrossbarCount = + current.logicalComputeCount == 0 ? 0 : current.crossbarCount / current.logicalComputeCount; uint64_t totalEntryInstructionCount = current.instructionCount * current.logicalComputeCount; llvm::SmallVector perCoreFields = { {"Number of logical computes", std::to_string(perCoreLogicalComputeCount)}, {"Number of instructions", std::to_string(perCoreInstructionCount) }, - {"Number of used crossbars", std::to_string(perCoreWeightCount) } + {"Number of used crossbars", std::to_string(perCoreCrossbarCount) } }; if (current.isRebatched) { llvm::SmallVector totalEntryFields = { {"Number of logical computes", std::to_string(current.logicalComputeCount)}, {"Number of instructions", std::to_string(totalEntryInstructionCount) }, - {"Number of used crossbars", std::to_string(current.weightCount) } + {"Number of used crossbars", std::to_string(current.crossbarCount) } }; printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields); } @@ -655,7 +653,7 @@ public: } emitMergeIrCounts("final-post-merge", func); dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); - generateReport(func, "dcp_merge_report", analysisResult->cpuToLastComputeMap.size()); + generateReport(func, "spatial_merge_report", analysisResult->cpuToLastComputeMap.size()); } } }; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index 0a079a7..184c309 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -1,4 +1,8 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Unit.h" @@ -9,12 +13,14 @@ #include #include +#include #include #include #include #include #include "ComputeGraph.hpp" +#include "ComputeInstanceUtils.hpp" #include "src/Support/TypeUtilities.hpp" namespace onnx_mlir { @@ -22,15 +28,42 @@ namespace spatial { using namespace mlir; +uint64_t countComputeBodyInstructions(Region& body); +uint64_t countComputeBodyOperationInstances(Region& body); + namespace { -Weight getComputeBodyWeight(Region& body) { - constexpr Weight kOperationWeight = 100; - Weight numOperations = 0; - for (auto& block : body) - for ([[maybe_unused]] auto& op : block) - numOperations = checkedAdd(numOperations, static_cast(1)); - return checkedMultiply(numOperations, kOperationWeight); +Cost getComputeBodyCost(Region& body) { + constexpr Cost kOperationCost = 100; + return checkedMultiply(static_cast(countComputeBodyOperationInstances(body)), kOperationCost); +} + +std::optional getStaticTripCount(scf::ForOp loop) { + auto lb = getConstantIntValue(loop.getLowerBound()); + auto ub = getConstantIntValue(loop.getUpperBound()); + auto step = getConstantIntValue(loop.getStep()); + if (!lb || !ub || !step || *step <= 0) + return std::nullopt; + if (*ub <= *lb) + return 0; + + uint64_t distance = static_cast(*ub - *lb); + uint64_t stride = static_cast(*step); + return (distance + stride - 1) / stride; +} + +uint64_t countOperationInstances(Operation& op) { + if (auto loop = dyn_cast(&op)) { + std::optional tripCount = getStaticTripCount(loop); + if (!tripCount) + return 1; + return checkedMultiply(countComputeBodyOperationInstances(loop.getRegion()), *tripCount); + } + + uint64_t instances = 1; + for (Region& region : op.getRegions()) + instances = checkedAdd(instances, countComputeBodyOperationInstances(region)); + return instances; } bool isUsedAsWeightOnly(Operation* producerOp) { @@ -61,7 +94,7 @@ bool isLaneOffset(OpFoldResult offset, Value laneArg) { return offsetValue == laneArg; } -std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) { +std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, Value input) { auto inputIt = llvm::find(batch.getInputs(), input); if (inputIt == batch.getInputs().end()) return std::nullopt; @@ -72,7 +105,7 @@ std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, if (!inputArg || !laneArg) return std::nullopt; - Weight projectedCost = 0; + Cost projectedCost = 0; for (Operation* user : inputArg->getUsers()) { auto extract = dyn_cast(user); if (!extract || extract.getSource() != *inputArg) @@ -83,7 +116,7 @@ std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, auto resultType = dyn_cast(extract.getResult().getType()); if (!resultType || !resultType.hasStaticShape()) return std::nullopt; - projectedCost = checkedAdd(projectedCost, static_cast(getSizeInBytes(resultType))); + projectedCost = checkedAdd(projectedCost, static_cast(getSizeInBytes(resultType))); } if (projectedCost == 0) @@ -91,28 +124,286 @@ std::optional getBatchProjectedInputTransferCost(SpatComputeBatch batch, return projectedCost; } -Weight getInputTransferCost(const ComputeInstance& consumerInstance, Value input) { +Cost getInputTransferCost(const ComputeInstance& consumerInstance, Value input) { auto inputType = cast(input.getType()); if (auto batch = dyn_cast(consumerInstance.op)) - if (std::optional projectedCost = getBatchProjectedInputTransferCost(batch, input)) + if (std::optional projectedCost = getBatchProjectedInputTransferCost(batch, input)) return *projectedCost; - return static_cast(getSizeInBytes(inputType)); + return static_cast(getSizeInBytes(inputType)); +} + +static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional lane) { + CrossbarWeight weight; + weight.opaqueValue = value; + weight.opaqueLane = lane.value_or(std::numeric_limits::max()); + return weight; +} + +static FailureOr evaluateAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { + if (auto constant = dyn_cast(expr)) + return constant.getValue(); + if (auto dim = dyn_cast(expr)) { + unsigned position = dim.getPosition(); + if (position >= dims.size()) + return failure(); + return dims[position]; + } + if (auto symbol = dyn_cast(expr)) { + unsigned position = symbol.getPosition(); + if (position >= symbols.size()) + return failure(); + return symbols[position]; + } + + auto binary = dyn_cast(expr); + if (!binary) + return failure(); + + FailureOr lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols); + FailureOr rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols); + if (failed(lhs) || failed(rhs)) + return failure(); + + auto floorDiv = [](int64_t value, int64_t divisor) -> FailureOr { + if (divisor <= 0) + return failure(); + if (value >= 0) + return value / divisor; + return -((-value + divisor - 1) / divisor); + }; + + switch (binary.getKind()) { + case AffineExprKind::Add: return *lhs + *rhs; + case AffineExprKind::Mul: return *lhs * *rhs; + case AffineExprKind::FloorDiv: return floorDiv(*lhs, *rhs); + case AffineExprKind::CeilDiv: + if (*rhs <= 0) + return failure(); + return (*lhs + *rhs - 1) / *rhs; + case AffineExprKind::Mod: { + FailureOr div = floorDiv(*lhs, *rhs); + if (failed(div)) + return failure(); + return *lhs - *div * *rhs; + } + default: return failure(); + } +} + +static FailureOr +evaluateIndexLike(Value value, const DenseMap& bindings, std::optional lane, Value laneArg); + +static FailureOr evaluateIndexLike(OpFoldResult value, + const DenseMap& bindings, + std::optional lane, + Value laneArg) { + if (auto attr = llvm::dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return failure(); + return intAttr.getInt(); + } + return evaluateIndexLike(llvm::cast(value), bindings, lane, laneArg); +} + +static FailureOr evaluateIndexLike(Value value, + const DenseMap& bindings, + std::optional lane, + Value laneArg) { + if (lane && value == laneArg) + return *lane; + if (auto it = bindings.find(value); it != bindings.end()) + return it->second; + + if (auto constant = value.getDefiningOp()) + return constant.value(); + + if (auto constant = value.getDefiningOp()) + if (auto intAttr = dyn_cast(constant.getValue())) + return intAttr.getInt(); + + if (auto extract = value.getDefiningOp()) { + auto constant = extract.getTensor().getDefiningOp(); + auto elements = constant ? dyn_cast(constant.getValue()) : nullptr; + auto shapedType = elements ? dyn_cast(elements.getType()) : nullptr; + if (!elements || !shapedType || shapedType.getRank() != 1 || extract.getIndices().size() != 1) + return failure(); + + FailureOr index = evaluateIndexLike(extract.getIndices().front(), bindings, lane, laneArg); + if (failed(index) || *index < 0 || *index >= static_cast(elements.getNumElements())) + return failure(); + + if (auto denseInts = dyn_cast(elements)) + return (*(denseInts.value_begin() + *index)).getSExtValue(); + return failure(); + } + + auto affineApply = value.getDefiningOp(); + if (!affineApply) + return failure(); + + AffineMap map = affineApply.getAffineMap(); + if (map.getNumResults() != 1) + return failure(); + + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr folded = evaluateIndexLike(operand, bindings, lane, laneArg); + if (failed(folded)) + return failure(); + operands.push_back(*folded); + } + + ArrayRef dims(operands.data(), map.getNumDims()); + ArrayRef symbols(operands.data() + map.getNumDims(), map.getNumSymbols()); + return evaluateAffineExpr(map.getResult(0), dims, symbols); +} + +static FailureOr> +evaluateIndexList(ArrayRef values, + const DenseMap& bindings, + std::optional lane, + Value laneArg) { + SmallVector result; + result.reserve(values.size()); + for (OpFoldResult value : values) { + FailureOr folded = evaluateIndexLike(value, bindings, lane, laneArg); + if (failed(folded)) + return failure(); + result.push_back(*folded); + } + return result; +} + +static Value resolveCrossbarWeightRoot(Operation* owner, Value root) { + if (auto arg = dyn_cast(root)) { + if (auto compute = dyn_cast(owner)) { + for (auto [index, operand] : llvm::enumerate(compute.getWeights())) + if (compute.getWeightArgument(index) == arg) + return operand; + } + + if (auto batch = dyn_cast(owner)) { + for (auto [index, operand] : llvm::enumerate(batch.getWeights())) + if (batch.getWeightArgument(index) == arg) + return operand; + } + } + + return root; +} + +static CrossbarWeight completeCrossbarWeight(Value root, + SmallVector offsets, + SmallVector sizes, + SmallVector strides) { + CrossbarWeight weight; + weight.root = root; + if (auto constant = root.getDefiningOp()) + weight.rootAttr = static_cast(constant.getValue()); + weight.offsets = std::move(offsets); + weight.sizes = std::move(sizes); + weight.strides = std::move(strides); + return weight; +} + +static FailureOr +getStaticCrossbarWeight(Operation* owner, + Value value, + const DenseMap& bindings, + std::optional lane, + Value laneArg) { + if (auto extract = value.getDefiningOp()) { + FailureOr sourceWeight = + getStaticCrossbarWeight(owner, extract.getSource(), bindings, lane, laneArg); + auto offsets = evaluateIndexList(extract.getMixedOffsets(), bindings, lane, laneArg); + auto sizes = evaluateIndexList(extract.getMixedSizes(), bindings, lane, laneArg); + auto strides = evaluateIndexList(extract.getMixedStrides(), bindings, lane, laneArg); + if (failed(sourceWeight) || failed(offsets) || failed(sizes) || failed(strides)) + return failure(); + + if (sourceWeight->offsets.size() != offsets->size() || sourceWeight->sizes.size() != sizes->size() + || sourceWeight->strides.size() != strides->size()) { + return failure(); + } + + for (auto [index, offset] : llvm::enumerate(*offsets)) { + sourceWeight->offsets[index] += offset * sourceWeight->strides[index]; + sourceWeight->sizes[index] = (*sizes)[index]; + sourceWeight->strides[index] *= (*strides)[index]; + } + return *sourceWeight; + } + + Value root = resolveCrossbarWeightRoot(owner, value); + auto type = dyn_cast(root.getType()); + if (!type || !type.hasStaticShape()) + return failure(); + + SmallVector offsets(type.getRank(), 0); + SmallVector sizes(type.getShape().begin(), type.getShape().end()); + SmallVector strides(type.getRank(), 1); + return completeCrossbarWeight(root, std::move(offsets), std::move(sizes), std::move(strides)); +} + +static void addCrossbarWeight(CrossbarUsage& usage, CrossbarWeight weight) { + if (!containsCrossbarWeight(usage, weight)) + usage.push_back(std::move(weight)); +} + +static void collectCrossbarWeightsFromOp(Operation* op, + Operation* owner, + DenseMap& bindings, + CrossbarUsage& usage, + Value laneArg, + std::optional lane) { + if (auto loop = dyn_cast(op)) { + auto lb = getConstantIntValue(loop.getLowerBound()); + auto ub = getConstantIntValue(loop.getUpperBound()); + auto step = getConstantIntValue(loop.getStep()); + if (!lb || !ub || !step || *step <= 0) + return; + + for (int64_t iv = *lb; iv < *ub; iv += *step) { + bindings[loop.getInductionVar()] = iv; + for (Operation& nested : loop.getBody()->without_terminator()) + collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane); + } + bindings.erase(loop.getInductionVar()); + return; + } + + if (auto vmm = dyn_cast(op)) { + FailureOr weight = getStaticCrossbarWeight(owner, vmm.getWeight(), bindings, lane, laneArg); + if (failed(weight)) { + addCrossbarWeight(usage, getOpaqueCrossbarWeight(vmm.getWeight(), lane)); + return; + } + addCrossbarWeight(usage, *weight); + return; + } + + for (Region& region : op->getRegions()) + for (Block& block : region) + for (Operation& nested : block.without_terminator()) + collectCrossbarWeightsFromOp(&nested, owner, bindings, usage, laneArg, lane); } std::vector aggregateEdges(llvm::ArrayRef edges) { - llvm::DenseMap, Weight> edgeWeights; + llvm::DenseMap, Cost> edgeCosts; for (const ComputeGraphEdge& edge : edges) { if (edge.source == edge.target) continue; - auto inserted = edgeWeights.try_emplace({edge.source, edge.target}, edge.transferCost); + auto inserted = edgeCosts.try_emplace({edge.source, edge.target}, edge.transferCost); if (!inserted.second) inserted.first->second = std::max(inserted.first->second, edge.transferCost); } std::vector aggregatedEdges; - aggregatedEdges.reserve(edgeWeights.size()); - for (const auto& [key, weight] : edgeWeights) - aggregatedEdges.push_back({key.first, key.second, weight}); + aggregatedEdges.reserve(edgeCosts.size()); + for (const auto& [key, cost] : edgeCosts) + aggregatedEdges.push_back({key.first, key.second, cost}); llvm::sort(aggregatedEdges, [](const ComputeGraphEdge& lhs, const ComputeGraphEdge& rhs) { if (lhs.source != rhs.source) return lhs.source < rhs.source; @@ -123,30 +414,75 @@ std::vector aggregateEdges(llvm::ArrayRef ed } // namespace -Weight getComputeInstanceWeight(const ComputeInstance& instance) { +uint64_t countComputeBodyInstructions(Region& body) { + uint64_t numOperations = 0; + body.walk([&](Operation* op) { numOperations = checkedAdd(numOperations, static_cast(1)); }); + return numOperations; +} + +uint64_t countComputeBodyOperationInstances(Region& body) { + uint64_t instances = 0; + for (Block& block : body) + for (Operation& op : block) + instances = checkedAdd(instances, countOperationInstances(op)); + return instances; +} + +CrossbarUsage collectDistinctCrossbarWeights(Operation* owner, std::optional lane) { + CrossbarUsage usage; + DenseMap bindings; + Value laneArg; + if (auto batch = dyn_cast(owner)) + if (auto maybeLaneArg = batch.getLaneArgument()) + laneArg = *maybeLaneArg; + + for (Region& region : owner->getRegions()) + for (Block& block : region) + for (Operation& op : block.without_terminator()) + collectCrossbarWeightsFromOp(&op, owner, bindings, usage, laneArg, lane); + return usage; +} + +Cost getComputeInstanceCost(const ComputeInstance& instance) { if (auto spatCompute = dyn_cast(instance.op)) - return getComputeBodyWeight(spatCompute.getBody()); + return getComputeBodyCost(spatCompute.getBody()); auto batch = cast(instance.op); - return checkedMultiply(getComputeBodyWeight(batch.getBody()), static_cast(instance.laneCount)); + return checkedMultiply(getComputeBodyCost(batch.getBody()), static_cast(instance.laneCount)); } -CrossbarUsage getSpatComputeCrossbarUsage(SpatCompute spatComute){ - CrossbarUsage ret; - ret.insert_range(spatComute.getWeights()); - return ret; +bool containsCrossbarWeight(ArrayRef usage, const CrossbarWeight& weight) { + return llvm::is_contained(usage, weight); } -CrossbarUsage getSpatComputeBatchCrossbarUsage(SpatComputeBatch spatComuteBatch){ - CrossbarUsage ret; - ret.insert_range(spatComuteBatch.getWeights()); - return ret; +unsigned countCrossbarOverlap(ArrayRef lhs, ArrayRef rhs) { + unsigned overlap = 0; + for (const CrossbarWeight& weight : rhs) + if (containsCrossbarWeight(lhs, weight)) + ++overlap; + return overlap; +} + +size_t getCrossbarUnionSize(ArrayRef lhs, ArrayRef rhs) { + size_t size = lhs.size(); + for (const CrossbarWeight& weight : rhs) + if (!containsCrossbarWeight(lhs, weight)) + ++size; + return size; +} + +void insertCrossbarWeights(CrossbarUsage& usage, ArrayRef weights) { + for (const CrossbarWeight& weight : weights) + addCrossbarWeight(usage, weight); } CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance) { - if (auto spatCompute = dyn_cast(instance.op)) - return getSpatComputeCrossbarUsage(spatCompute); - auto batch = cast(instance.op); - return getSpatComputeBatchCrossbarUsage(batch); + CrossbarUsage usage; + if (isa(instance.op)) + return collectDistinctCrossbarWeights(instance.op); + + for (uint32_t lane = instance.laneStart; lane < instance.laneStart + instance.laneCount; ++lane) + insertCrossbarWeights(usage, collectDistinctCrossbarWeights(instance.op, lane)); + return usage; } ComputeGraph buildComputeGraph(Operation* entryOp) { @@ -161,7 +497,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { ComputeInstance instance {spatCompute.getOperation(), 0, 1}; size_t index = graph.nodes.size(); graph.nodes.push_back( - {instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); + {instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index}); graph.instanceToIndex[instance] = index; continue; } @@ -173,7 +509,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { ComputeInstance instance = getBatchChunkForIndex(batch, chunkIndex); size_t index = graph.nodes.size(); graph.nodes.push_back( - {instance, getComputeInstanceWeight(instance), getComputeInstanceCrossbarUsage(instance), index}); + {instance, getComputeInstanceCost(instance), getComputeInstanceCrossbarUsage(instance), index}); graph.instanceToIndex[instance] = index; } } @@ -185,7 +521,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { for (const auto& [targetIndex, node] : llvm::enumerate(graph.nodes)) { llvm::SmallVector inputs = getComputeInstanceInputs(node.instance); for (Value input : inputs) { - Weight transferCost = getInputTransferCost(node.instance, input); + Cost transferCost = getInputTransferCost(node.instance, input); if (auto producerBatch = dyn_cast_or_null(input.getDefiningOp()); producerBatch && producerBatch.getNumResults() != 0 && !isa(node.instance.op)) { for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) { @@ -208,7 +544,7 @@ ComputeGraph buildComputeGraph(Operation* entryOp) { } std::vector aggregatedEdges = aggregateEdges(rawEdges); - graph.edges.append(aggregatedEdges.begin(), aggregatedEdges.end()); + graph.edges.insert(graph.edges.end(), aggregatedEdges.begin(), aggregatedEdges.end()); graph.successors.assign(graph.nodes.size(), {}); graph.predecessors.assign(graph.nodes.size(), {}); for (const ComputeGraphEdge& edge : graph.edges) { @@ -233,8 +569,8 @@ bool verifyAcyclic(const ComputeGraph& graph) { size_t node = readyNodes.front(); readyNodes.pop(); ++visited; - for (const auto& [child, weight] : graph.successors[node]) { - (void) weight; + for (const auto& [child, cost] : graph.successors[node]) { + (void) cost; assert(remainingParents[child] > 0 && "remaining parent count underflow"); if (--remainingParents[child] == 0) readyNodes.push(child); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp index 0882a55..a7c3c97 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.hpp @@ -1,52 +1,72 @@ #pragma once -#include "mlir/IR/Operation.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Value.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include #include #include #include -#include "Utils.hpp" #include "ComputeInstance.hpp" -#include "ComputeInstanceUtils.hpp" +#include "Utils.hpp" -using CrossbarUsage = llvm::SmallPtrSet; +struct CrossbarWeight { + mlir::Value root; + mlir::Attribute rootAttr; + llvm::SmallVector offsets; + llvm::SmallVector sizes; + llvm::SmallVector strides; + mlir::Value opaqueValue; + uint32_t opaqueLane = 0; + + bool operator==(const CrossbarWeight& other) const { + bool sameRoot = rootAttr && other.rootAttr ? rootAttr == other.rootAttr : root == other.root; + return sameRoot && offsets == other.offsets && sizes == other.sizes && strides == other.strides + && opaqueValue == other.opaqueValue && opaqueLane == other.opaqueLane; + } +}; + +using CrossbarUsage = llvm::SmallVector; namespace onnx_mlir { namespace spatial { struct ComputeGraphNode { ComputeInstance instance; - Weight weight = 0; - llvm::SmallPtrSet crossbarUsage; + Cost cost = 0; + CrossbarUsage crossbarUsage; size_t originalOrder = 0; }; struct ComputeGraphEdge { size_t source = 0; size_t target = 0; - Weight transferCost = 0; + Cost transferCost = 0; }; struct ComputeGraph { - llvm::SmallVector nodes; - llvm::SmallVector edges; - std::vector>> successors; - std::vector>> predecessors; + std::vector nodes; + std::vector edges; + std::vector>> successors; + std::vector>> predecessors; llvm::DenseMap instanceToIndex; }; ComputeGraph buildComputeGraph(mlir::Operation* entryOp); bool verifyAcyclic(const ComputeGraph& graph); -Weight getComputeInstanceWeight(const ComputeInstance& instance); +uint64_t countComputeBodyInstructions(mlir::Region& body); +uint64_t countComputeBodyOperationInstances(mlir::Region& body); +Cost getComputeInstanceCost(const ComputeInstance& instance); +CrossbarUsage collectDistinctCrossbarWeights(mlir::Operation* owner, std::optional lane = std::nullopt); CrossbarUsage getComputeInstanceCrossbarUsage(const ComputeInstance& instance); +bool containsCrossbarWeight(llvm::ArrayRef usage, const CrossbarWeight& weight); +unsigned countCrossbarOverlap(llvm::ArrayRef lhs, llvm::ArrayRef rhs); +size_t getCrossbarUnionSize(llvm::ArrayRef lhs, llvm::ArrayRef rhs); +void insertCrossbarWeights(CrossbarUsage& usage, llvm::ArrayRef weights); } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp index 644367b..4fcef5d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp @@ -48,12 +48,12 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result return lhs.second < rhs.second; }); - unsigned int usedCrossbars = 0; + CrossbarUsage usedCrossbars; for (size_t slot = 0; slot < scheduledTasks.size(); ++slot) { if (scheduledTasks[slot].first != slot) llvm::report_fatal_error("merge scheduling: CPU slots are not contiguous"); - usedCrossbars = addOrMax(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage.size()); - if (usedCrossbars > crossbarCapacity) + insertCrossbarWeights(usedCrossbars, graph.nodes[scheduledTasks[slot].second].crossbarUsage); + if (usedCrossbars.size() > crossbarCapacity) llvm::report_fatal_error("merge scheduling: CPU crossbar capacity exceeded"); } @@ -77,7 +77,7 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result if (sourceCpu == targetCpu && sourceSlot >= targetSlot) llvm::report_fatal_error("merge scheduling: same-CPU dependency order is invalid"); - Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].weight); + Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost); if (sourceCpu != targetCpu) earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost); if (targetStart < earliestTargetStart) { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index 30759b9..70d4cd5 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -89,8 +89,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu std::vector> reverseLevels = buildReverseLevels(graph); // MOCK: Replace this with your actual heterogeneous cost lookup. - // If graph.nodes[task] is modified to hold a vector of weights per processor, access it here. - auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].weight; }; + // If graph.nodes[task] is modified to hold a vector of costs per processor, access it here. + auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; }; std::vector