diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index db6033a..b189bd0 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/Support/LogicalResult.h" @@ -72,9 +73,37 @@ static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t va return getOrCreateHostIndexConstant(anchorOp, value, rewriter); } +/* Returns the integer carried by value when it is recognizably constant: first through the op found by getDefiningOp (its value() result), otherwise through matchPattern with m_ConstantInt (sign-extended APInt). Yields std::nullopt when neither applies. */ static std::optional getConstantIndexValue(Value value) { + if (auto constantIndex = value.getDefiningOp()) + return constantIndex.value(); + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) + return constantValue.getSExtValue(); + + return std::nullopt; +} + static Value createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); + + /* Collect each operand as an index attribute; the first operand without a known constant value makes this return a plain affine.apply. */ SmallVector operandConstants; + operandConstants.reserve(operands.size()); + for (Value operand : operands) { + std::optional constantValue = getConstantIndexValue(operand); + if (!constantValue) + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); + operandConstants.push_back(rewriter.getIndexAttr(*constantValue)); + } + + /* Every operand is constant here: try map.constantFold and, when the first folded result casts successfully, return it through createIndexConstant; otherwise fall through to the affine.apply below. */ SmallVector foldedResults; + if (succeeded(map.constantFold(operandConstants, foldedResults))) { + auto constantResult = dyn_cast(foldedResults.front()); + if (constantResult) + return createIndexConstant(rewriter, constantResult.getInt()); + } + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); } @@ -115,7 +144,15 @@ static Value createGemmBatchKOffset( } static Value 
createGemmBatchHOffset( - Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { + Value lane, + int64_t numOutRows, + int64_t numKSlices, + int64_t numOutHSlices, + ConversionPatternRewriter& rewriter, + Location loc) { + /* When numOutHSlices is 1, return the index constant 0 up front and skip the affine expression built below. */ if (numOutHSlices == 1) + return createIndexConstant(rewriter, 0); + MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); return createAffineApply( @@ -290,6 +327,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, RankedTensorType partialPiecesType, int64_t numOutRows, int64_t numKSlices, + int64_t numOutHSlices, ConversionPatternRewriter& rewriter, Location loc) { const int64_t laneCount = partialPiecesType.getDimSize(0); @@ -314,7 +352,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc); Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc); - Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, rewriter, loc); + Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); auto aTileType = RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType()); auto bTileType = RankedTensorType::get( @@ -610,7 +648,8 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, auto partialPiecesType = RankedTensorType::get({laneCount64, static_cast(crossbarSize.getValue())}, outType.getElementType()); - auto batchOp = createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, rewriter, loc); + auto batchOp = + createVmmBatch(a, b, aType, paddedBType, partialPiecesType, numOutRows, numKSlices, numOutHSlices, rewriter, loc); auto reductionCompute = createReductionCompute( batchOp.getResult(0), bias, partialPiecesType, outType, paddedOutType, numKSlices, rewriter, loc); diff --git 
a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index ed34d6b..d9441e1 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1019,6 +1019,27 @@ std::optional getIndexedIndexPattern(ArrayRef valu return std::nullopt; } +/* Emits affine.apply for map over operands unless every operand has a value according to getConstantIntValue and map.constantFold succeeds with a castable first result; in that case the folded integer is returned through createIndexConstant(state, anchor, ...) instead. */ Value createAffineApplyOrConstant( + MaterializerState& state, Location loc, AffineMap map, ValueRange operands, Operation* anchor) { + SmallVector operandConstants; + operandConstants.reserve(operands.size()); + for (Value operand : operands) { + auto constantValue = getConstantIntValue(operand); + if (!constantValue) + return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult(); + operandConstants.push_back(state.rewriter.getIndexAttr(*constantValue)); + } + + SmallVector foldedResults; + if (succeeded(map.constantFold(operandConstants, foldedResults))) { + auto constantResult = dyn_cast(foldedResults.front()); + if (constantResult) + return createIndexConstant(state, anchor, constantResult.getInt()); + } + + return affine::AffineApplyOp::create(state.rewriter, loc, map, operands).getResult(); +} + Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) { MLIRContext* context = state.func.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); @@ -1033,7 +1054,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern } AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {index}).getResult(); + return createAffineApplyOrConstant(state, loc, map, ValueRange {index}, state.func); } Value createIndexedIndexValue( @@ -3346,7 +3367,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, 
MaterializedClass& targe int64_t laneCount = static_cast(targetClass.cpus.size()); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); - return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}).getResult(); + return createAffineApplyOrConstant(state, loc, map, ValueRange {slotIndex, *laneArg}, state.func); } Value createBatchClassRunSourceLane(MaterializerState& state,