diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index effdd6d..81a02a6 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -1,4 +1,5 @@ add_pim_library(OMPimCommon + IR/AffineUtils.cpp IR/AddressAnalysis.cpp IR/BatchCoreUtils.cpp IR/ConstantUtils.cpp diff --git a/src/PIM/Common/IR/AffineUtils.cpp b/src/PIM/Common/IR/AffineUtils.cpp new file mode 100644 index 0000000..7fd3d23 --- /dev/null +++ b/src/PIM/Common/IR/AffineUtils.cpp @@ -0,0 +1,182 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Matchers.h" + +#include "AffineUtils.hpp" +#include "ConstantUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +static FailureOr floorDivSigned(int64_t lhs, int64_t rhs) { + if (rhs <= 0) + return failure(); + + int64_t quotient = lhs / rhs; + int64_t remainder = lhs % rhs; + if (remainder != 0 && lhs < 0) + --quotient; + return quotient; +} + +static FailureOr ceilDivSigned(int64_t lhs, int64_t rhs) { + if (rhs <= 0) + return failure(); + + int64_t quotient = lhs / rhs; + int64_t remainder = lhs % rhs; + if (remainder != 0 && lhs > 0) + ++quotient; + return quotient; +} + +Value createOrFoldAffineApply( + RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map"); + + SmallVector operandConstants; + operandConstants.reserve(operands.size()); + for (Value operand : operands) { + std::optional constantValue = matchConstantIndexValue(operand); + if (!constantValue) + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); + operandConstants.push_back(rewriter.getIndexAttr(*constantValue)); + } + + SmallVector foldedResults; + if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1) + if (auto constantResult = dyn_cast(foldedResults.front())) + return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt()); + + return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); +} + +Value createOrFoldAffineApply( + RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) { + AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr); + return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor); +} + +Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + if (multiplier == 0) + return getOrCreateIndexConstant(rewriter, constantAnchor, 0); + if (multiplier == 1) + return value; + + AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); + return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor); +} + +Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + assert(divisor > 0 && "expected a positive affine.mod divisor"); + if (divisor == 1) + return getOrCreateIndexConstant(rewriter, constantAnchor, 0); + + AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); + return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor); +} + +Value affineFloorDivConst( + RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + assert(divisor > 0 && "expected a positive affine.floor_div divisor"); + if (divisor == 1) + return value; + + AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); + return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor); +} + +FailureOr evaluateAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { + if (auto constant = dyn_cast(expr)) + return constant.getValue(); + if (auto dim = dyn_cast(expr)) { + unsigned position = dim.getPosition(); + if (position >= dims.size()) + return failure(); + return dims[position]; + } + if (auto symbol = dyn_cast(expr)) { + unsigned position = symbol.getPosition(); + if (position >= symbols.size()) + return failure(); + return symbols[position]; + } + + auto binary = dyn_cast(expr); + if (!binary) + return failure(); + + FailureOr lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols); + FailureOr rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols); + if (failed(lhs) || failed(rhs)) + return failure(); + + switch (binary.getKind()) { + case AffineExprKind::Add: return *lhs + *rhs; + case AffineExprKind::Mul: return *lhs * *rhs; + case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs); + case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs); + case AffineExprKind::Mod: { + FailureOr div = floorDivSigned(*lhs, *rhs); + if (failed(div)) + return failure(); + return *lhs - *div * *rhs; + } + default: return failure(); + } +} + +FailureOr evaluateSingleResultAffineMap(AffineMap map, ArrayRef operands) { + if (map.getNumResults() != 1 || operands.size() != map.getNumInputs()) + return failure(); + + ArrayRef dims(operands.data(), map.getNumDims()); + ArrayRef symbols(operands.data() + map.getNumDims(), map.getNumSymbols()); + return evaluateAffineExpr(map.getResult(0), dims, symbols); +} + +FailureOr evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) { + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr folded = resolver(operand); + if (failed(folded)) + return failure(); + operands.push_back(*folded); + } + + return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands); +} + +bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; } + +bool isDimAndConstantAffineExpr(AffineExpr expr) { + switch (expr.getKind()) { + case AffineExprKind::Constant: + case AffineExprKind::DimId: return true; + case AffineExprKind::SymbolId: return false; + case AffineExprKind::Add: { + auto binaryExpr = cast(expr); + return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()); + } + case AffineExprKind::Mul: { + auto binaryExpr = cast(expr); + return (isa(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS())) + || (isa(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS())); + } + case AffineExprKind::FloorDiv: + case AffineExprKind::CeilDiv: + case AffineExprKind::Mod: { + auto binaryExpr = cast(expr); + return isa(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()); + } + } + + llvm_unreachable("unexpected affine expression kind"); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/AffineUtils.hpp b/src/PIM/Common/IR/AffineUtils.hpp new file mode 100644 index 0000000..ed2d585 --- /dev/null +++ b/src/PIM/Common/IR/AffineUtils.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +#include "llvm/ADT/FunctionExtras.h" + +namespace onnx_mlir { + +using IndexValueResolver = llvm::function_ref(mlir::Value)>; + +mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::AffineMap map, + mlir::ValueRange operands, + mlir::Operation* constantAnchor); + +mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::AffineExpr expr, + mlir::ValueRange dims, + mlir::Operation* constantAnchor); + +mlir::Value affineMulConst(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value value, + int64_t multiplier, + mlir::Operation* constantAnchor); + +mlir::Value affineModConst(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value value, + int64_t divisor, + mlir::Operation* constantAnchor); + +mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value value, + int64_t divisor, + mlir::Operation* constantAnchor); + +llvm::FailureOr +evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef dims, llvm::ArrayRef symbols = {}); + +llvm::FailureOr evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef operands); + +llvm::FailureOr evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver); + +bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map); + +bool isDimAndConstantAffineExpr(mlir::AffineExpr expr); + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index ca2db1f..fde70f3 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -1,10 +1,8 @@ -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/Matchers.h" #include "ConstantUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -75,24 +73,27 @@ Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int6 return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); } -Value createAffineApplyOrFoldedConstant( - RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* anchorOp) { - SmallVector operandConstants; - operandConstants.reserve(operands.size()); - for (Value operand : operands) { - APInt constantValue; - if (!matchPattern(operand, m_ConstantInt(&constantValue))) - return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); - operandConstants.push_back(rewriter.getIndexAttr(constantValue.getSExtValue())); - } +std::optional matchConstantIndexValue(Value value) { + if (!value || !value.getType().isIndex()) + return std::nullopt; - SmallVector foldedResults; - if (succeeded(map.constantFold(operandConstants, foldedResults))) { - if (auto constantResult = dyn_cast(foldedResults.front())) - return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt()); - } + if (auto constant = value.getDefiningOp()) + return constant.value(); - return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); + if (auto constant = value.getDefiningOp()) + if (auto intAttr = dyn_cast(constant.getValue()); intAttr && intAttr.getType().isIndex()) + return intAttr.getInt(); + + return std::nullopt; +} + +std::optional matchConstantIndexValue(OpFoldResult value) { + if (auto attr = dyn_cast(value)) + if (auto intAttr = dyn_cast(attr); intAttr && intAttr.getType().isIndex()) + return intAttr.getInt(); + if (auto operand = dyn_cast(value)) + return matchConstantIndexValue(operand); + return std::nullopt; } } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp index ae87f4c..a0a96f5 100644 --- a/src/PIM/Common/IR/ConstantUtils.hpp +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -1,11 +1,12 @@ #pragma once -#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/FoldUtils.h" +#include + namespace onnx_mlir { mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp); @@ -22,10 +23,8 @@ mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operat mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); -mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, - mlir::Location loc, - mlir::AffineMap map, - mlir::ValueRange operands, - mlir::Operation* anchorOp); +std::optional matchConstantIndexValue(mlir::Value value); + +std::optional matchConstantIndexValue(mlir::OpFoldResult value); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp index 400fc81..0033b72 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp @@ -1,11 +1,6 @@ -#include "mlir/Dialect/Arith/IR/Arith.h" - -#include "llvm/ADT/APInt.h" - #include #include "IndexingUtils.hpp" -#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" using namespace mlir; @@ -47,46 +42,4 @@ FailureOr> normalizeAxesChecked(std::optional ax return normalizedAxes; } -Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { - AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); - Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp); -} - -Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) { - if (multiplier == 0) - return getOrCreateIndexConstant(rewriter, anchorOp, 0); - if (multiplier == 1) - return value; - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value}); -} - -Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { - if (divisor == 1) - return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value}); -} - -Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { - if (divisor == 1) - return value; - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}); -} - -Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) { - if (auto attr = dyn_cast(value)) - return getOrCreateIndexConstant( - rewriter, rewriter.getInsertionBlock()->getParentOp(), cast(attr).getInt()); - return cast(value); -} - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp index 143be3a..f798e75 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp @@ -1,11 +1,7 @@ #pragma once -#include "mlir/IR/AffineExpr.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Value.h" -#include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" @@ -21,21 +17,4 @@ int64_t normalizeIndex(int64_t index, int64_t dimSize); mlir::FailureOr> normalizeAxesChecked(std::optional axesAttr, int64_t rank); -mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter, - mlir::Location loc, - mlir::AffineExpr expr, - mlir::ValueRange operands); - -mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter, - mlir::Operation* anchorOp, - mlir::Value value, - int64_t multiplier); - -mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); - -mlir::Value -floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); - -mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value); - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp index 0a9567a..f1e9771 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp @@ -11,6 +11,7 @@ #include +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -27,8 +28,7 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) { } static bool hasConstantIndices(tensor::ExtractOp extractOp) { - return llvm::all_of(extractOp.getIndices(), - [](Value index) { return isa_and_nonnull(index.getDefiningOp()); }); + return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); }); } static bool isStaticTensorResult(Operation* op) { diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index feefba9..8fc6080 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -14,6 +14,7 @@ #include #include "Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" @@ -60,8 +61,11 @@ static Value createGemmBatchKOffset( MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant( - rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); + return createOrFoldAffineApply(rewriter, + loc, + (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), + ValueRange {lane}, + rewriter.getInsertionBlock()->getParentOp()); } static Value createGemmBatchHOffset(Value lane, @@ -75,8 +79,11 @@ static Value createGemmBatchHOffset(Value lane, MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant( - rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); + return createOrFoldAffineApply(rewriter, + loc, + d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), + ValueRange {lane}, + rewriter.getInsertionBlock()->getParentOp()); } static Value @@ -259,7 +266,8 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) { - Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows); + Value row = + onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutRows, rewriter.getInsertionBlock()->getParentOp()); Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc); Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); @@ -297,7 +305,8 @@ createDynamicGemmBatchRow(Value lane, int64_t numOutCols, ConversionPatternRewri MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); + return createOrFoldAffineApply( + rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp()); } static Value extractDynamicGemmBColumn( @@ -429,7 +438,8 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) { Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc); - Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols); + Value column = + onnx_mlir::affineModConst(rewriter, loc, args.lane, numOutCols, rewriter.getInsertionBlock()->getParentOp()); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); @@ -477,7 +487,8 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, Value lane = loop.getInductionVar(); Value outputAcc = loop.getRegionIterArgs().front(); Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); - Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols); + Value column = + onnx_mlir::affineModConst(rewriter, loc, lane, numOutCols, rewriter.getInsertionBlock()->getParentOp()); SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -515,8 +526,11 @@ static Value createPartialGroupOffset(Value hSlice, Location loc) { MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrFoldedConstant( - rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); + return createOrFoldAffineApply(rewriter, + loc, + d0 * (numKSlices * numOutRows) + kSlice * numOutRows, + ValueRange {hSlice}, + rewriter.getInsertionBlock()->getParentOp()); } static Value extractReductionPiece(Value partialPiecesArg, @@ -597,8 +611,8 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces, auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { Value reduced = reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); - Value hOffset = onnx_mlir::multiplyIndexByConstant( - rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); + Value hOffset = onnx_mlir::affineMulConst( + rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); if (biasArg) { SmallVector biasOffsets {rewriter.getIndexAttr(0), hOffset}; Value biasSlice = diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 7037125..57c99b4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -7,6 +7,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" @@ -276,7 +277,8 @@ static Value extractBatchedBTile(Value b, static Value getBatchLaneIndex( Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) { - return floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices); + return affineFloorDivConst( + rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp()); } static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, @@ -300,16 +302,15 @@ static spatial::SpatComputeBatch createBatchedVmmBatch(Value a, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) { - Value row = modIndexByConstant(rewriter, loc, args.lane, numOutRows); - Value outerLane = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp); + Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp); Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); - Value sliceLane = modIndexByConstant(rewriter, loc, outerLane, numKSlices * numOutHSlices); - Value kSlice = modIndexByConstant(rewriter, loc, sliceLane, numKSlices); - Value hSlice = floorDivIndexByConstant(rewriter, loc, sliceLane, numKSlices); - Value kOffset = - multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice, crossbarSize.getValue()); - Value hOffset = - multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); + Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp); + Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp); + Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp); + Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp); + Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp); auto aTileType = RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType()); @@ -445,10 +446,11 @@ static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) { - Value batch = floorDivIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols); - Value batchLane = modIndexByConstant(rewriter, loc, args.lane, numOutRows * numOutCols); - Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols); - Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp); + Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp); + Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp); + Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); @@ -491,10 +493,11 @@ static Value createBatchedDynamicOutputCompute(Value scalarPieces, Value lane = loop.getInductionVar(); Value outputAcc = loop.getRegionIterArgs().front(); - Value batch = floorDivIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols); - Value batchLane = modIndexByConstant(rewriter, loc, lane, numOutRows * numOutCols); - Value row = floorDivIndexByConstant(rewriter, loc, batchLane, numOutCols); - Value column = modIndexByConstant(rewriter, loc, batchLane, numOutCols); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp); + Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp); + Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp); + Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp); SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; Value scalar = tensor::ExtractSliceOp::create( @@ -542,10 +545,9 @@ static Value extractBatchedReductionPiece(Value partialPiecesArg, int64_t numOutRows, PatternRewriter& rewriter, Location loc) { - Value batchOffset = multiplyIndexByConstant( - rewriter, rewriter.getInsertionBlock()->getParentOp(), batch, numOutRows * numKSlices * numOutHSlices); - Value hOffset = - multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, numKSlices * numOutRows); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp); + Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp); Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows); Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset); Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset); @@ -631,7 +633,7 @@ static Value createBatchedReductionCompute(Value partialPieces, {2} }); Value hOffset = - multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, crossbarSize.getValue()); + affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp()); SmallVector outputOffsets {batch, rewriter.getIndexAttr(0), hOffset}; SmallVector outputSizes { rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())}; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index e89d7c2..bf4ddab 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -7,6 +7,7 @@ #include #include +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" @@ -82,7 +83,7 @@ computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternR expr = expr.floorDiv(stride); if (dimSize != 1) expr = expr % dimSize; - return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane}); + return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp()); } static FailureOr buildReduceMeanKeepdimsBatch(Value input, diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index 115771f..2639423 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -13,6 +13,8 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/Support/LogicalResult.h" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" @@ -56,48 +58,19 @@ static LogicalResult verifyStaticWeights(ComputeOpTy computeOp, StringRef kind) return success(); } -static bool isConstantIndexLike(Value value) { - APInt constantValue; - return matchPattern(value, m_ConstantInt(&constantValue)); -} - -static bool isSupportedLaneAffineExpr(AffineExpr expr) { - switch (expr.getKind()) { - case AffineExprKind::Constant: - case AffineExprKind::DimId: return true; - case AffineExprKind::SymbolId: return false; - case AffineExprKind::Add: { - auto binaryExpr = cast(expr); - return isSupportedLaneAffineExpr(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS()); - } - case AffineExprKind::Mul: { - auto binaryExpr = cast(expr); - return (isa(binaryExpr.getLHS()) && isSupportedLaneAffineExpr(binaryExpr.getRHS())) - || (isa(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS())); - } - case AffineExprKind::FloorDiv: - case AffineExprKind::CeilDiv: - case AffineExprKind::Mod: { - auto binaryExpr = cast(expr); - return isa(binaryExpr.getRHS()) && isSupportedLaneAffineExpr(binaryExpr.getLHS()); - } - } - llvm_unreachable("unexpected affine expression kind"); -} - static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { - if (value == laneArg || isConstantIndexLike(value)) + if (value == laneArg || matchConstantIndexValue(value)) return true; auto affineApply = value.getDefiningOp(); if (affineApply) { - if (affineApply.getAffineMap().getNumResults() != 1 || affineApply.getAffineMap().getNumSymbols() != 0) + if (!isSingleResultSymbolFreeAffineMap(affineApply.getAffineMap())) return false; if (!llvm::all_of(affineApply.getMapOperands(), [&](Value operand) { return isSupportedLaneOffsetExpr(operand, laneArg); })) { return false; } - return isSupportedLaneAffineExpr(affineApply.getAffineMap().getResult(0)); + return isDimAndConstantAffineExpr(affineApply.getAffineMap().getResult(0)); } auto extractOp = value.getDefiningOp(); @@ -112,8 +85,8 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { auto addOp = value.getDefiningOp(); if (!addOp) return false; - return (addOp.getLhs() == laneArg && isConstantIndexLike(addOp.getRhs())) - || (addOp.getRhs() == laneArg && isConstantIndexLike(addOp.getLhs())); + return (addOp.getLhs() == laneArg && matchConstantIndexValue(addOp.getRhs())) + || (addOp.getRhs() == laneArg && matchConstantIndexValue(addOp.getLhs())); } static LogicalResult diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index d33061a..4a56f9e 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -21,6 +21,7 @@ #include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -1053,7 +1054,7 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern } AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); - return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {index}, state.func); + return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func); } Value createIndexedIndexValue( @@ -1295,47 +1296,16 @@ std::optional getLaneProjectedDim(ArrayRef offsets, Valu } static FailureOr evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) { - if (auto attr = dyn_cast(value)) { - auto intAttr = dyn_cast(attr); - if (!intAttr) - return failure(); - return intAttr.getInt(); - } + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; Value current = cast(value); if (current == laneArg) return static_cast(lane); - if (auto constant = current.getDefiningOp()) - return constant.value(); - - if (auto constant = current.getDefiningOp()) - if (auto intAttr = dyn_cast(constant.getValue())) - return intAttr.getInt(); - - if (auto affineApply = current.getDefiningOp()) { - AffineMap map = affineApply.getAffineMap(); - if (map.getNumResults() != 1) - return failure(); - - SmallVector operandConstants; - operandConstants.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr folded = evaluateProjectedOffsetForLane(operand, laneArg, lane); - if (failed(folded)) - return failure(); - operandConstants.push_back(IntegerAttr::get(IndexType::get(current.getContext()), *folded)); - } - - SmallVector foldedResults; - if (failed(map.constantFold(operandConstants, foldedResults)) || foldedResults.size() != 1) - return failure(); - - auto constantResult = dyn_cast(foldedResults.front()); - if (!constantResult) - return failure(); - return constantResult.getInt(); - } + if (auto affineApply = current.getDefiningOp()) + return evaluateAffineApply(affineApply, + [&](Value operand) { return evaluateProjectedOffsetForLane(operand, laneArg, lane); }); return failure(); } @@ -3503,7 +3473,7 @@ Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targe int64_t laneCount = static_cast(targetClass.cpus.size()); AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); - return createAffineApplyOrFoldedConstant(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); + return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); } Value createBatchClassRunSourceLane(MaterializerState& state, diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp index 128eb13..95c9c22 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/ComputeGraph.cpp @@ -21,6 +21,8 @@ #include "ComputeGraph.hpp" #include "ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Support/TypeUtilities.hpp" namespace onnx_mlir { @@ -148,57 +150,6 @@ static CrossbarWeight getOpaqueCrossbarWeight(Value value, std::optional evaluateAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { - if (auto constant = dyn_cast(expr)) - return constant.getValue(); - if (auto dim = dyn_cast(expr)) { - unsigned position = dim.getPosition(); - if (position >= dims.size()) - return failure(); - return dims[position]; - } - if (auto symbol = dyn_cast(expr)) { - unsigned position = symbol.getPosition(); - if (position >= symbols.size()) - return failure(); - return symbols[position]; - } - - auto binary = dyn_cast(expr); - if (!binary) - return failure(); - - FailureOr lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols); - FailureOr rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols); - if (failed(lhs) || failed(rhs)) - return failure(); - - auto floorDiv = [](int64_t value, int64_t divisor) -> FailureOr { - if (divisor <= 0) - return failure(); - if (value >= 0) - return value / divisor; - return -((-value + divisor - 1) / divisor); - }; - - switch (binary.getKind()) { - case AffineExprKind::Add: return *lhs + *rhs; - case AffineExprKind::Mul: return *lhs * *rhs; - case AffineExprKind::FloorDiv: return floorDiv(*lhs, *rhs); - case AffineExprKind::CeilDiv: - if (*rhs <= 0) - return failure(); - return (*lhs + *rhs - 1) / *rhs; - case AffineExprKind::Mod: { - FailureOr div = floorDiv(*lhs, *rhs); - if (failed(div)) - return failure(); - return *lhs - *div * *rhs; - } - default: return failure(); - } -} - static FailureOr evaluateIndexLike(Value value, const DenseMap& bindings, std::optional lane, Value laneArg); @@ -222,12 +173,8 @@ evaluateIndexLike(Value value, const DenseMap& bindings, std::op if (auto it = bindings.find(value); it != bindings.end()) return it->second; - if (auto constant = value.getDefiningOp()) - return constant.value(); - - if (auto constant = value.getDefiningOp()) - if (auto intAttr = dyn_cast(constant.getValue())) - return intAttr.getInt(); + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; if (auto extract = value.getDefiningOp()) { auto constant = extract.getTensor().getDefiningOp(); @@ -245,26 +192,11 @@ evaluateIndexLike(Value value, const DenseMap& bindings, std::op return failure(); } - auto affineApply = value.getDefiningOp(); - if (!affineApply) - return failure(); + if (auto affineApply = value.getDefiningOp()) + return evaluateAffineApply(affineApply, + [&](Value operand) { return evaluateIndexLike(operand, bindings, lane, laneArg); }); - AffineMap map = affineApply.getAffineMap(); - if (map.getNumResults() != 1) - return failure(); - - SmallVector operands; - operands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr folded = evaluateIndexLike(operand, bindings, lane, laneArg); - if (failed(folded)) - return failure(); - operands.push_back(*folded); - } - - ArrayRef dims(operands.data(), map.getNumDims()); - ArrayRef symbols(operands.data() + map.getNumDims(), map.getNumSymbols()); - return evaluateAffineExpr(map.getResult(0), dims, symbols); + return failure(); } static FailureOr> evaluateIndexList(ArrayRef values,