diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index d687415..f49d26a 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -749,18 +749,12 @@ llvm::FailureOr compileContiguousAddressExprImpl(mlir::Valu } // namespace -llvm::FailureOr resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); } - llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) { return resolveIndexValueImpl(value, &knowledge); } llvm::FailureOr compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); } -llvm::FailureOr resolveContiguousAddress(mlir::Value value) { - return resolveContiguousAddressImpl(value, nullptr); -} - llvm::FailureOr resolveContiguousAddress(mlir::Value value, const StaticValueKnowledge& knowledge) { return resolveContiguousAddressImpl(value, &knowledge); diff --git a/src/PIM/Common/IR/AddressAnalysis.hpp b/src/PIM/Common/IR/AddressAnalysis.hpp index d2fead4..7a2240b 100644 --- a/src/PIM/Common/IR/AddressAnalysis.hpp +++ b/src/PIM/Common/IR/AddressAnalysis.hpp @@ -77,14 +77,12 @@ mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::m /// Resolves a value to contiguous backing storage when that storage can be /// proven statically from aliases, DPS ties, casts, and subviews. -llvm::FailureOr resolveContiguousAddress(mlir::Value value); llvm::FailureOr resolveContiguousAddress(mlir::Value value, - const StaticValueKnowledge& knowledge); + const StaticValueKnowledge& knowledge = {}); /// Statically evaluates index-like SSA values, including simple integer /// arithmetic and loop facts recorded in `knowledge`. -llvm::FailureOr resolveIndexValue(mlir::Value value); -llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge); +llvm::FailureOr resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}); llvm::FailureOr compileIndexExpr(mlir::Value value); /// Follows alias, view, and DPS chains to recover the backing value of a diff --git a/src/PIM/Common/IR/BatchCoreUtils.cpp b/src/PIM/Common/IR/BatchCoreUtils.cpp index 0baf7fc..3b7aa07 100644 --- a/src/PIM/Common/IR/BatchCoreUtils.cpp +++ b/src/PIM/Common/IR/BatchCoreUtils.cpp @@ -17,4 +17,18 @@ llvm::SmallVector getLaneChunkCoreIds(llvm::ArrayRef coreIds, return laneCoreIds; } +bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) { + if (mlir::isa(op)) + return operandIndex == 3; + if (mlir::isa(op)) + return operandIndex == 1; + if (mlir::isa(op)) + return operandIndex == 2; + return false; +} + +bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) { + return mlir::isa(op) && operandIndex == 2; +} + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/BatchCoreUtils.hpp b/src/PIM/Common/IR/BatchCoreUtils.hpp index 41351d6..58eb57b 100644 --- a/src/PIM/Common/IR/BatchCoreUtils.hpp +++ b/src/PIM/Common/IR/BatchCoreUtils.hpp @@ -11,4 +11,8 @@ llvm::SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp); llvm::SmallVector getLaneChunkCoreIds(llvm::ArrayRef coreIds, size_t laneCount, unsigned lane); +bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex); + +bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex); + } // namespace onnx_mlir diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index 31e2ea3..ca2db1f 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -14,13 +14,17 @@ using namespace mlir; namespace onnx_mlir { -Block* getHostConstantBlock(Operation* anchorOp) { +Block* getConstantInsertionBlock(Operation* anchorOp) { assert(anchorOp && "expected a valid anchor operation"); for (Operation* current = anchorOp; current; current = current->getParentOp()) if (isa(current)) return current->getBlock(); + if (auto funcOp = dyn_cast(anchorOp)) + return &funcOp.getBody().front(); + if (auto moduleOp = dyn_cast(anchorOp)) + return moduleOp.getBody(); if (auto funcOp = anchorOp->getParentOfType()) return &funcOp.getBody().front(); if (auto moduleOp = anchorOp->getParentOfType()) @@ -28,9 +32,9 @@ Block* getHostConstantBlock(Operation* anchorOp) { return anchorOp->getBlock(); } -Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) { +Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) { assert(anchorOp && "expected a valid anchor operation"); - Block* hostBlock = getHostConstantBlock(anchorOp); + Block* hostBlock = getConstantInsertionBlock(anchorOp); for (Operation& op : *hostBlock) { auto constantOp = dyn_cast(&op); if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value) @@ -42,9 +46,9 @@ Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attr return folder.getOrCreateConstant(hostBlock, arithDialect, value, type); } -Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) { +Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) { assert(anchorOp && "expected a valid anchor operation"); - Block* hostBlock = getHostConstantBlock(anchorOp); + Block* hostBlock = getConstantInsertionBlock(anchorOp); for (Operation& op : *hostBlock) { auto constantOp = dyn_cast(&op); if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value) @@ -57,28 +61,18 @@ Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attri return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast(value)).getResult(); } -Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) { - return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType()); +Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) { + return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType()); } -Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) { +Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) { Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() ); + return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); } -Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) { +Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) { Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); -} - -Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) { - Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() ); -} - -Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) { - Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() ); + return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); } Value createAffineApplyOrFoldedConstant( @@ -95,7 +89,7 @@ Value createAffineApplyOrFoldedConstant( SmallVector foldedResults; if (succeeded(map.constantFold(operandConstants, foldedResults))) { if (auto constantResult = dyn_cast(foldedResults.front())) - return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt()); + return getOrCreateIndexConstant(rewriter, anchorOp, constantResult.getInt()); } return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp index c241fc9..e496a20 100644 --- a/src/PIM/Common/IR/ConstantUtils.hpp +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -8,27 +8,23 @@ namespace onnx_mlir { -mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp); +mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp); -mlir::Value getOrCreateHostConstant(mlir::OperationFolder& folder, - mlir::Operation* anchorOp, - mlir::Attribute value, - mlir::Type type); +mlir::Value getOrCreateConstant(mlir::OperationFolder& folder, + mlir::Operation* anchorOp, + mlir::Attribute value, + mlir::Type type); -mlir::Value getOrCreateHostConstant(mlir::RewriterBase& rewriter, - mlir::Operation* anchorOp, - mlir::Attribute value, - mlir::Type type); +mlir::Value getOrCreateConstant(mlir::RewriterBase& rewriter, + mlir::Operation* anchorOp, + mlir::Attribute value, + mlir::Type type); -mlir::Value getOrCreateHostConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp); +mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp); -mlir::Value getOrCreateHostIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value); +mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value); -mlir::Value getOrCreateHostIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); - -mlir::Value getOrCreateHostI32Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int32_t value); - -mlir::Value getOrCreateHostI64Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value); +mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, mlir::Location loc, diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 1d3b7be..4b3d678 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -178,10 +178,6 @@ std::optional resolveWeightIndex(mlir::Operation* weightOwner, mlir::V return std::nullopt; } -std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp) { - return resolveWeightIndex(weightOwner, vmmOp.getWeight()); -} - llvm::FailureOr resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) { llvm::SmallVector viewOps; diff --git a/src/PIM/Common/IR/WeightUtils.hpp b/src/PIM/Common/IR/WeightUtils.hpp index be4cbb1..1298725 100644 --- a/src/PIM/Common/IR/WeightUtils.hpp +++ b/src/PIM/Common/IR/WeightUtils.hpp @@ -46,7 +46,6 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value); void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref callback); std::optional resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight); -std::optional resolveWeightIndex(mlir::Operation* weightOwner, pim::PimVMMOp vmmOp); llvm::FailureOr resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {}); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp index 745ac5c..0c24977 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp @@ -40,14 +40,6 @@ static SmallVector normalizeAxesImpl(std::optional axesAttr, return normalizedAxes; } -SmallVector normalizeAxes(ArrayAttr axesAttr, int64_t rank) { - return normalizeAxesImpl(std::optional(axesAttr), rank); -} - -SmallVector normalizeAxes(std::optional axesAttr, int64_t rank) { - return normalizeAxesImpl(axesAttr, rank); -} - FailureOr> normalizeAxesChecked(std::optional axesAttr, int64_t rank) { SmallVector normalizedAxes = normalizeAxesImpl(axesAttr, rank); for (int64_t axis : normalizedAxes) @@ -56,11 +48,7 @@ FailureOr> normalizeAxesChecked(std::optional ax return normalizedAxes; } -FailureOr> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) { - return normalizeAxesChecked(std::optional(axesAttr), rank); -} - -Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { +Value createAffineApplyOrFoldedConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp); @@ -68,22 +56,22 @@ Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, Affin Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) { if (multiplier == 0) - return getOrCreateHostIndexConstant(rewriter, anchorOp, 0); + return getOrCreateIndexConstant(rewriter, anchorOp, 0); if (multiplier == 1) return value; MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value}); + return createAffineApplyOrFoldedConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value}); } Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { if (divisor == 1) - return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value}); + return createAffineApplyOrFoldedConstant(rewriter, loc, d0 % divisor, ValueRange {value}); } Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { @@ -92,12 +80,12 @@ Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value val MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}); + return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}); } -Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) { +Value getOrMaterializeIndexValue(PatternRewriter& rewriter, OpFoldResult value) { if (auto attr = dyn_cast(value)) - return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast(attr).getInt()); + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast(attr).getInt()); return cast(value); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp index 3c20806..2d90496 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp @@ -19,18 +19,12 @@ mlir::FailureOr normalizeAxisChecked(int64_t axis, int64_t rank); int64_t normalizeIndex(int64_t index, int64_t dimSize); -llvm::SmallVector normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank); - -llvm::SmallVector normalizeAxes(std::optional axesAttr, int64_t rank); - -mlir::FailureOr> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank); - mlir::FailureOr> normalizeAxesChecked(std::optional axesAttr, int64_t rank); -mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter, - mlir::Location loc, - mlir::AffineExpr expr, - mlir::ValueRange operands); +mlir::Value createAffineApplyOrFoldedConstant(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::AffineExpr expr, + mlir::ValueRange operands); mlir::Value multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier); @@ -40,6 +34,6 @@ mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location l mlir::Value floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); -mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value); +mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::OpFoldResult value); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 54c283b..ae510a6 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -10,6 +10,7 @@ #include "ShapeTilingUtils.hpp" #include "IndexingUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" @@ -19,10 +20,6 @@ using namespace mlir; namespace onnx_mlir { -static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) { - return getOrMaterializeIndexValue(rewriter, loc, result); -} - static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) { APInt lhsConst; if (matchPattern(lhs, m_ConstantInt(&lhsConst)) && lhsConst.isZero()) @@ -43,11 +40,12 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt return arith::MulIOp::create(rewriter, loc, value, cast(factor)).getResult(); if (factorConst.isZero()) - return arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); if (factorConst.isOne()) return value; - auto factorValue = arith::ConstantIndexOp::create(rewriter, loc, factorConst.getSExtValue()).getResult(); + auto factorValue = + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), factorConst.getSExtValue()); return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult(); } @@ -61,8 +59,6 @@ int64_t getStaticShapeElementCount(ArrayRef shape) { return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies {}); } -int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); } - SmallVector permuteShape(ArrayRef shape, ArrayRef permutation) { SmallVector permutedShape; permutedShape.reserve(permutation.size()); @@ -226,49 +222,6 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri return slicesPerCore; } -DenseMap>> tileMatrix( - Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) { - assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile))); - - DenseMap>> tiles; - - SmallVector hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc); - size_t numHSlices = hSlices.size(); - for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) { - Value hSlice = hSlices[hSliceId]; - SmallVector vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc); - for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) { - size_t coreId = vSliceId / crossbarCountInCore; - Value vSlice = vSlices[vSliceId]; - tiles[hSliceId][coreId].push_back(vSlice); - } - } - return tiles; -} - -Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) { - auto oldType = cast(scalarToBroadcast.getType()); - Type elementType = oldType.getElementType(); - int64_t shape[2] = {1, length}; - Type type = oldType.cloneWith(ArrayRef(shape), elementType); - - auto buildBroadcast = [&](Value input) -> Value { - auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); - SmallVector index(oldType.getRank(), zero); - auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult(); - return tensor::SplatOp::create(rewriter, loc, type, elementValue); - }; - - if (isCompileTimeComputable(scalarToBroadcast)) - return buildBroadcast(scalarToBroadcast); - - auto broadcastCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) { - spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input)); - }); - return broadcastCompute.getResult(0); -} - Value materializeContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef offsets, @@ -294,7 +247,7 @@ Value materializeContiguousTensorSlice(Value source, Value init = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult(); SmallVector zeroIndices(resultType.getRank()); for (Value& zeroIndex : zeroIndices) - zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + zeroIndex = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); SmallVector resultIndices; resultIndices.reserve(resultType.getRank()); @@ -304,7 +257,7 @@ Value materializeContiguousTensorSlice(Value source, SmallVector sourceIndices; sourceIndices.reserve(resultType.getRank()); for (unsigned idx = 0; idx < resultType.getRank(); ++idx) { - Value offsetValue = getIndexValue(offsets[idx], rewriter, loc); + Value offsetValue = getOrMaterializeIndexValue(rewriter, offsets[idx]); Value scaledIndex = multiplyIndexValue(resultIndices[idx], strides[idx], rewriter, loc); sourceIndices.push_back(addIndexValues(offsetValue, scaledIndex, rewriter, loc)); } @@ -337,8 +290,8 @@ Value materializeContiguousTensorSlice(Value source, } Value lower = zeroIndices[dim]; - Value upper = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(dim)).getResult(); - Value step = arith::ConstantIndexOp::create(rewriter, loc, 1).getResult(); + Value upper = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultType.getDimSize(dim)); + Value step = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); auto loop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {accumulator}); rewriter.setInsertionPointToStart(loop.getBody()); resultIndices.push_back(loop.getInductionVar()); @@ -352,17 +305,6 @@ Value materializeContiguousTensorSlice(Value source, return buildLoopNest(buildLoopNest, 0, init); } -Value extractStaticSlice(PatternRewriter& rewriter, - Location loc, - Value source, - RankedTensorType resultType, - ArrayRef offsets) { - return tensor::ExtractSliceOp::create( - rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()), - getUnitStrides(rewriter, resultType.getRank())) - .getResult(); -} - Value extractAxisSlice( PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { auto sourceType = cast(source.getType()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index 3e90301..f0367bd 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -18,41 +18,6 @@ namespace onnx_mlir { -template -inline auto getImageWidth(const ShapedType& shapedType) { - return shapedType.getDimSize(2); -} - -template -inline auto getImageHeight(const ShapedType& shapedType) { - return shapedType.getDimSize(3); -} - -template -inline auto getImageChannel(const ShapedType& shapedType) { - return shapedType.getDimSize(1); -} - -template -inline auto getImageN(const ShapedType& shapedType) { - return shapedType.getDimSize(0); -} - -template -inline auto getKernelWidth(const ShapedType& shapedType) { - return shapedType.getDimSize(2); -} - -template -inline auto getKernelHeight(const ShapedType& shapedType) { - return shapedType.getDimSize(3); -} - -template -inline auto getFilterCount(const ShapedType& shapedType) { - return shapedType.getDimSize(0); -} - using HSliceId = size_t; using CoreId = size_t; @@ -89,17 +54,6 @@ bool isHVectorShape(mlir::ArrayRef shape) { return shape.size() == 2 && shape[0] == 1; } -template -bool isVVectorShape(mlir::ArrayRef shape) { - return shape.size() == 2 && shape[1] == 1; -} - -template -T getVectorLength(mlir::ArrayRef shape) { - assert(isVectorShape(shape)); - return shape[0] != 1 ? shape[0] : shape[1]; -} - inline auto getTensorShape(mlir::Value tensor) { return mlir::cast(tensor.getType()).getShape(); } @@ -117,8 +71,6 @@ bool hasStaticPositiveShape(mlir::RankedTensorType type); int64_t getStaticShapeElementCount(mlir::ArrayRef shape); -int64_t getStaticShapeElementCount(mlir::RankedTensorType type); - llvm::SmallVector permuteShape(mlir::ArrayRef shape, mlir::ArrayRef permutation); llvm::SmallVector invertPermutation(mlir::ArrayRef permutation); @@ -156,20 +108,6 @@ llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, llvm::DenseMap> sliceVectorPerCrossbarPerCore( const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); -/// Tiles a matrix first across output columns and then across input rows so it -/// can be assigned to crossbars grouped by core. -llvm::DenseMap>> -tileMatrix(mlir::Value& matrixToTile, - int64_t hSliceSize, - int64_t vSliceSize, - mlir::ConversionPatternRewriter& rewriter, - mlir::Location& loc); - -mlir::Value broadcastToVector(mlir::Value scalarToBroadcast, - int64_t length, - mlir::ConversionPatternRewriter& rewriter, - mlir::Location loc); - mlir::Value materializeContiguousTensorSlice(mlir::Value source, mlir::RankedTensorType resultType, llvm::ArrayRef offsets, @@ -177,12 +115,6 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); -mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter, - mlir::Location loc, - mlir::Value source, - mlir::RankedTensorType resultType, - llvm::ArrayRef offsets); - mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, diff --git a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp index c947a2a..0a9567a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp @@ -8,10 +8,10 @@ #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/ErrorHandling.h" #include +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -38,13 +38,6 @@ static bool isStaticTensorResult(Operation* op) { }); } -static SmallVector computeRowMajorStrides(ArrayRef shape) { - SmallVector strides(shape.size(), 1); - for (int64_t dim = static_cast(shape.size()) - 2; dim >= 0; --dim) - strides[dim] = strides[dim + 1] * shape[dim + 1]; - return strides; -} - static FailureOr transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef perms) { auto tensorType = dyn_cast(denseAttr.getType()); if (!tensorType) diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 4a15427..ca04343 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -61,9 +61,9 @@ static Value createPaddedRows(Value tensorValue, padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); - auto zero = arith::ConstantOp::create( - rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType())); - tensor::YieldOp::create(rewriter, loc, zero.getResult()); + auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), + tensorType.getElementType()); + tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(padOp); return padOp.getResult(); } @@ -106,7 +106,7 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr, } auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues); - return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType); } static Value createConvWeightMatrix(Value w, @@ -158,7 +158,7 @@ static Value buildPackedBias(bool hasBias, auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType()); auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues); - return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType); } static Value createIm2colRowComputes(Value x, @@ -214,8 +214,8 @@ static Value createIm2colRowComputes(Value x, padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); - auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0)); - tensor::YieldOp::create(rewriter, loc, zero.getResult()); + auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType); + tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(padOp); paddedInput = padOp.getResult(); } @@ -223,13 +223,14 @@ static Value createIm2colRowComputes(Value x, // Build im2col [numPatches, patchSize] incrementally to keep the IR small // until the late PIM unrolling step. Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType); - auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); - auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); - auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches); - auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch); - auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth); - auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); - auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches); + auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch); + auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth); + auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); + auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit}); rewriter.setInsertionPointToStart(im2colLoop.getBody()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp index 4615c76..2912e49 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp @@ -83,7 +83,7 @@ static FailureOr materializeBroadcastedConstantTensor(Value value, } auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues); - return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), broadcastedAttr, resultType); } static FailureOr @@ -121,7 +121,7 @@ static FailureOr materializeReciprocalTensor(Value value, } auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues); - return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), reciprocalAttr, resultType); } template diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 5c13ea4..78bd79b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -50,38 +50,17 @@ materializeScaledConstantTensor(Value value, float factor, ConversionPatternRewr return failure(); auto scaledAttr = DenseFPElementsAttr::get(cast(denseAttr.getType()), scaledValues); - return arith::ConstantOp::create(rewriter, loc, denseAttr.getType(), scaledAttr).getResult(); -} - -static Value transposeForSpatial(Value value, - RankedTensorType resultType, - ArrayRef permutation, - ConversionPatternRewriter& rewriter, - Location loc) { - return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc); -} - -static Value -multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) { - return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier); -} - -static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) { - return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor); -} - -static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) { - return modIndexByConstant(lane, numOutRows, rewriter, loc); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaledAttr, denseAttr.getType()); } static Value createGemmBatchKOffset( Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { if (numKSlices == 1) - return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant( + return createAffineApplyOrFoldedConstant( rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); } @@ -92,11 +71,11 @@ static Value createGemmBatchHOffset(Value lane, ConversionPatternRewriter& rewriter, Location loc) { if (numOutHSlices == 1) - return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant( + return createAffineApplyOrFoldedConstant( rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); } @@ -115,9 +94,9 @@ createZeroPaddedTensor(Value value, RankedTensorType resultType, ConversionPatte padBlock->addArgument(rewriter.getIndexType(), loc); padOp.getRegion().push_back(padBlock); rewriter.setInsertionPointToStart(padBlock); - auto zero = arith::ConstantOp::create( - rewriter, loc, sourceType.getElementType(), rewriter.getZeroAttr(sourceType.getElementType())); - tensor::YieldOp::create(rewriter, loc, zero.getResult()); + auto zero = getOrCreateConstant( + rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType()); + tensor::YieldOp::create(rewriter, loc, zero); rewriter.setInsertionPointAfter(padOp); return padOp.getResult(); } @@ -149,7 +128,7 @@ static FailureOr materializePaddedConstantMatrix(Value value, resultValues[row * resultShape[1] + col] = sourceValues[row * sourceShape[1] + col]; auto resultAttr = DenseElementsAttr::get(resultType, resultValues); - return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType); } static FailureOr materializePaddedBroadcastedConstantTensor(Value value, @@ -215,7 +194,7 @@ static FailureOr materializePaddedBroadcastedConstantTensor(Value value, } auto resultAttr = DenseElementsAttr::get(resultType, resultValues); - return arith::ConstantOp::create(rewriter, loc, resultType, resultAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType); } static FailureOr prepareBias(Value c, @@ -274,7 +253,7 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, const int64_t laneCount = partialPiecesType.getDimSize(0); auto batchOp = createSpatComputeBatch( rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) { - Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc); + Value row = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutRows); Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc); Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); @@ -312,12 +291,7 @@ static Value createDynamicGemmBatchRow( MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); -} - -static Value createDynamicGemmBatchColumn( - Value lane, int64_t numOutCols, ConversionPatternRewriter& rewriter, Location loc) { - return modIndexByConstant(lane, numOutCols, rewriter, loc); + return createAffineApplyOrFoldedConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); } static Value @@ -385,7 +359,7 @@ static Value createScalarTensorConstant(RankedTensorType scalarType, auto elementType = scalarType.getElementType(); auto scalarAttr = rewriter.getFloatAttr(elementType, value); auto denseAttr = DenseElementsAttr::get(scalarType, scalarAttr); - return arith::ConstantOp::create(rewriter, loc, scalarType, denseAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), denseAttr, scalarType); } static Value createBroadcastedBiasScalar(Value bias, @@ -435,7 +409,7 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a, auto batchOp = createSpatComputeBatch( rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) { Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc); - Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc); + Value column = onnx_mlir::modIndexByConstant(rewriter, loc, args.lane, numOutCols); auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); @@ -475,16 +449,16 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, Value biasArg = bias ? blockArgs[1] : Value(); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); - Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); - Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); + Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); rewriter.setInsertionPointToStart(loop.getBody()); Value lane = loop.getInductionVar(); Value outputAcc = loop.getRegionIterArgs().front(); Value row = createDynamicGemmBatchRow(lane, numOutCols, rewriter, loc); - Value column = createDynamicGemmBatchColumn(lane, numOutCols, rewriter, loc); + Value column = onnx_mlir::modIndexByConstant(rewriter, loc, lane, numOutCols); SmallVector scalarOffsets {lane, rewriter.getIndexAttr(0)}; SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; @@ -522,7 +496,7 @@ static Value createPartialGroupOffset(Value hSlice, Location loc) { MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApplyOrConstant( + return createAffineApplyOrFoldedConstant( rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); } @@ -604,7 +578,9 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces, auto buildOutputSlice = [&](Value outputAcc, Value hSlice) -> Value { Value reduced = reducePartialPiecesForHSlice(partialPiecesArg, hSlice, pieceType, numKSlices, numOutRows, rewriter, loc); - Value hOffset = multiplyIndexByConstant(hSlice, crossbarSize.getValue(), rewriter, loc); + Value hOffset = + onnx_mlir::multiplyIndexByConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), hSlice, + crossbarSize.getValue()); if (biasArg) { SmallVector biasOffsets {rewriter.getIndexAttr(0), hOffset}; Value biasSlice = @@ -620,13 +596,14 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces, Value paddedOutput = outputInit; if (numOutHSlices == 1) { - Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value hSlice = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); paddedOutput = buildOutputSlice(outputInit, hSlice); } else { - Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); - Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); - Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); + Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cOutHSlices = + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); rewriter.setInsertionPointToStart(hLoop.getBody()); @@ -763,7 +740,7 @@ LogicalResult GemmToSpatialComputes::matchAndRewrite(ONNXGemmOp gemmOp, if (gemmOpAdaptor.getTransB()) { auto bShape = bType.getShape(); auto transposedType = RankedTensorType::get({bShape[1], bShape[0]}, bType.getElementType()); - b = transposeForSpatial(b, transposedType, {1, 0}, rewriter, loc); + b = transposeMaybeInCompute(b, transposedType, {1, 0}, rewriter, loc); bType = cast(b.getType()); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index 5181aa4..dfd8535 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -76,7 +76,7 @@ static Value computeLaneIndex(Value lane, ConversionPatternRewriter& rewriter, Location loc) { if (dimSize == 1) - return arith::ConstantIndexOp::create(rewriter, loc, 0); + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); @@ -85,7 +85,7 @@ static Value computeLaneIndex(Value lane, expr = expr.floorDiv(stride); if (dimSize != 1) expr = expr % dimSize; - return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane}); + return createAffineApplyOrFoldedConstant(rewriter, loc, expr, ValueRange {lane}); } static FailureOr buildReduceMeanKeepdimsBatch(Value input, @@ -236,7 +236,7 @@ static Value squeezeReducedAxes(Value keepdimsValue, Location loc) { if (resultType.getRank() == 0) { SmallVector indices(cast(keepdimsValue.getType()).getRank(), - arith::ConstantIndexOp::create(rewriter, loc, 0)); + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0)); Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices); return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element}); } @@ -268,7 +268,7 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern { return success(); } - auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank()); + auto axes = normalizeAxesChecked(std::optional(reduceMeanOp.getAxesAttr()), inputType.getRank()); if (failed(axes)) return failure(); SmallVector reducedAxes = buildReducedAxesMask(*axes, inputType.getRank()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index 80ecda4..7396daf 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -31,17 +31,18 @@ static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Loca static Value createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); if (!useMinimumValue) - return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); + return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType); if (auto floatType = dyn_cast(elementType)) { auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true); - return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue)); + return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType); } if (auto integerType = dyn_cast(elementType)) { auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth()); - return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue)); + return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType); } llvm_unreachable("unsupported pool element type"); @@ -148,7 +149,7 @@ static FailureOr createAverageScaleTensor(ConversionPatternRewriter& rewr } auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues); - return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult(); + return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType); } template @@ -265,13 +266,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()); - Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); - Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); - Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount); - Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth); - Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth); - Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); - Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount); + Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth); + Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth); + Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight); + Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth); auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); rewriter.setInsertionPointToStart(outputLoop.getBody()); @@ -296,14 +298,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { Value paddedInH = windowBaseH; if (kernelH * dilationHeight != 0) { - Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight); + Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight); paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset); } for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { Value paddedInW = windowBaseW; if (kernelW * dilationWidth != 0) { - Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth); + Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth); paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); } diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index bc95983..3e9f3ee 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -52,9 +52,10 @@ static Value buildLoopSoftmaxNest(Value input, if (axis == inputType.getRank() - 1) return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc); - Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); - Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); - Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis)); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis)); auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator}); rewriter.setInsertionPointToStart(loop.getBody()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index 56f1625..b8f8dbf 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -17,9 +17,10 @@ namespace { static Value buildNearestAsymmetricIndex( Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) { - Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim); - Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim); - Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value cInputDim = getOrCreateIndexConstant(rewriter, anchorOp, inputDim); + Value cOutputDim = getOrCreateIndexConstant(rewriter, anchorOp, outputDim); + Value cInputDimLast = getOrCreateIndexConstant(rewriter, anchorOp, inputDim - 1); Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim); Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim); return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast); @@ -37,12 +38,13 @@ static Value buildNearestResizeLoop(Value input, SmallVector unitSizes(resultType.getRank(), rewriter.getIndexAttr(1)); SmallVector unitStrides(resultType.getRank(), rewriter.getIndexAttr(1)); - Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); - Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); - Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0)); - Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1)); - Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2)); - Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3)); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0); + Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1); + Value cOutputN = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(0)); + Value cOutputC = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(1)); + Value cOutputH = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(2)); + Value cOutputW = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(3)); Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp index 7085589..c928093 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp @@ -1,9 +1,11 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -30,6 +32,54 @@ static Value createTransposeInit(Value input, return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult(); } +static FailureOr materializeTransposedConstant(Value input, + RankedTensorType resultType, + ArrayRef permutation, + ConversionPatternRewriter& rewriter, + Location loc) { + auto denseAttr = getHostConstDenseElementsAttr(input); + if (!denseAttr) + return failure(); + + auto inputType = dyn_cast(denseAttr.getType()); + if (!inputType || !inputType.hasStaticShape() || !resultType.hasStaticShape() + || inputType.getRank() != resultType.getRank() + || static_cast(permutation.size()) != inputType.getRank()) { + return failure(); + } + + if (denseAttr.isSplat()) + return getOrCreateConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + DenseElementsAttr::get(resultType, denseAttr.getSplatValue()), + resultType); + + SmallVector inputValues(denseAttr.getValues()); + SmallVector resultValues(inputValues.size()); + SmallVector inputStrides = computeRowMajorStrides(inputType.getShape()); + SmallVector resultStrides = computeRowMajorStrides(resultType.getShape()); + SmallVector inputIndices(inputType.getRank(), 0); + + for (auto [linearIndex, value] : llvm::enumerate(inputValues)) { + int64_t remaining = static_cast(linearIndex); + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) { + inputIndices[dim] = inputStrides.empty() ? 0 : remaining / inputStrides[dim]; + remaining = inputStrides.empty() ? 0 : remaining % inputStrides[dim]; + } + + int64_t resultLinearIndex = 0; + for (int64_t dim = 0; dim < resultType.getRank(); ++dim) + resultLinearIndex += inputIndices[permutation[dim]] * resultStrides[dim]; + + resultValues[resultLinearIndex] = value; + } + + return getOrCreateConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + DenseElementsAttr::get(resultType, resultValues), + resultType); +} + struct TransposeToLinalgTranspose : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -44,6 +94,14 @@ struct TransposeToLinalgTranspose : OpConversionPattern { auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank()); if (failed(permutation)) return failure(); + if (isCompileTimeComputable(adaptor.getData())) { + auto constantTranspose = + materializeTransposedConstant(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc()); + if (succeeded(constantTranspose)) { + rewriter.replaceOp(transposeOp, *constantTranspose); + return success(); + } + } Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc()); Value transposed = linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation) diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 7c071bc..bd2e9e2 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -7,6 +7,7 @@ #include "mlir/IR/Matchers.h" #include "Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" @@ -18,15 +19,9 @@ using namespace onnx_mlir::pim; namespace onnx_mlir { namespace { -static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { - if (isa(op)) - return operandIndex == 2; - return false; -} - static bool isUsedOnlyAsExplicitHostOperand(Value value) { return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) { - return isExplicitHostOperand(use.getOwner(), use.getOperandNumber()); + return isExplicitDevToHostTargetOperand(use.getOwner(), use.getOperandNumber()); }); } @@ -55,7 +50,7 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba if (scale == 1) return base; - auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult(); + auto scaleValue = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scale); return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult(); } @@ -77,7 +72,8 @@ static Value createHostTargetOffset(IRRewriter& rewriter, if (auto attr = dyn_cast(offset)) { auto intAttr = dyn_cast(attr); assert(intAttr && "expected integer offset attribute"); - scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult(); + scaledOffset = + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale); } else { scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast(offset)), scale); @@ -88,7 +84,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter, } if (!totalOffset) - totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + totalOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); return totalOffset; } @@ -214,7 +210,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc()); auto hostTargetType = cast(hostTarget.getType()); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); - Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult(); + Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); pim::PimMemCopyDevToHostOp::create(rewriter, insertSlice.getLoc(), hostTarget.getType(), @@ -254,7 +250,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) { if (!isa(operand.getType()) || mapper.contains(operand)) continue; - if (isExplicitHostOperand(&op, operandIndex)) + if (isExplicitDevToHostTargetOperand(&op, operandIndex)) continue; Operation* definingOp = operand.getDefiningOp(); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index a8c0fa5..5388dfa 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -40,7 +40,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite continue; if (auto constantOp = dyn_cast(definingOp)) { - mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp)); + mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp)); continue; } @@ -218,7 +218,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp continue; if (auto constantOp = input.getDefiningOp()) { - blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantFolder, constantOp)); + blockArg->replaceAllUsesWith(getOrCreateConstantLike(constantFolder, constantOp)); continue; } @@ -230,8 +230,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp PimMemCopyHostToDevOp::create(rewriter, loc, outputBuffer.getType(), - getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0), - getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0), + getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0), + getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0), outputBuffer, input, getTensorSizeInBytesAttr(rewriter, input)) diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index e619a0f..8a98192 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -16,25 +16,9 @@ void populateInitialPatterns(RewritePatternSet& patterns) { populateTransposeLoweringPatterns(patterns); } -void populateGlobalTensorMaterializationPatternPhase(RewritePatternSet& patterns) { - populateGlobalTensorMaterializationPatterns(patterns); -} - -void populateInitialTensorPackingPatterns(RewritePatternSet& patterns) { - populateTensorPackingPatterns(patterns); -} - void populateCoreBodyPatterns(RewritePatternSet& patterns) { raptor::populateWithGenerated(patterns); populateTransposeLoweringPatterns(patterns); } -void populateFinalTensorPackingPatterns(RewritePatternSet& patterns) { - populateTensorPackingPatterns(patterns); -} - -void populateCommunicationPatterns(RewritePatternSet& patterns) { - populateChannelLoweringPatterns(patterns); -} - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.hpp b/src/PIM/Conversion/SpatialToPim/Patterns.hpp index c1a7bad..e9b656f 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.hpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.hpp @@ -9,11 +9,7 @@ namespace onnx_mlir { void populateInitialPatterns(mlir::RewritePatternSet& patterns); -void populateGlobalTensorMaterializationPatternPhase(mlir::RewritePatternSet& patterns); -void populateInitialTensorPackingPatterns(mlir::RewritePatternSet& patterns); void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns); -void populateFinalTensorPackingPatterns(mlir::RewritePatternSet& patterns); -void populateCommunicationPatterns(mlir::RewritePatternSet& patterns); void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns); void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns); diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 18a5e1d..32dcea3 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -326,7 +326,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite continue; if (auto constantOp = dyn_cast(definingOp)) { - mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp)); + mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp)); continue; } @@ -370,8 +370,8 @@ static Value emitHostCopy(IRRewriter& rewriter, OperationFolder& constantFolder) { Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); - Value hostTargetOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, hostTargetOffset); - Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, deviceSourceOffset); + Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset); + Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset); return PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 8022cc4..56bda03 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -81,7 +81,7 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter, auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); - auto zeroIndex = getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0); + auto zeroIndex = getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(tensorType))); if (outputBuffer->getParentOfType()) @@ -160,7 +160,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } RewritePatternSet globalTensorPatterns(ctx); - populateGlobalTensorMaterializationPatternPhase(globalTensorPatterns); + populateGlobalTensorMaterializationPatterns(globalTensorPatterns); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); auto returnOp = cast(funcOp.front().getTerminator()); @@ -190,7 +190,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } RewritePatternSet initialTensorPackingPatterns(ctx); - populateInitialTensorPackingPatterns(initialTensorPackingPatterns); + populateTensorPackingPatterns(initialTensorPackingPatterns); walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns)); eraseUnusedTensorPackingOps(funcOp, rewriter); @@ -250,7 +250,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { eraseOpsToRemove(); RewritePatternSet finalTensorPackingPatterns(ctx); - populateFinalTensorPackingPatterns(finalTensorPackingPatterns); + populateTensorPackingPatterns(finalTensorPackingPatterns); walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns)); eraseUnusedTensorPackingOps(funcOp, rewriter); @@ -270,7 +270,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { spatial::SpatExtractRowsOp>(); RewritePatternSet communicationPatterns(ctx); - populateCommunicationPatterns(communicationPatterns); + populateChannelLoweringPatterns(communicationPatterns); if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops"); signalPassFailure(); @@ -333,8 +333,8 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables( rewriter, loc, tensorType, - getOrCreateHostIndexConstant(constantFolder, deviceTensor.getOperation(), 0), - getOrCreateHostIndexConstant(constantFolder, + getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), 0), + getOrCreateIndexConstant(constantFolder, deviceTensor.getOperation(), static_cast(elementsOffset * elementByteSize) ), deviceTensor, inputTensor, diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index 168f0d8..c29eda3 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -9,6 +9,7 @@ #include "llvm/Support/LogicalResult.h" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -19,16 +20,6 @@ namespace pim { namespace { -static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { - if (isa(op)) - return operandIndex == 3; - if (isa(op)) - return operandIndex == 1; - if (isa(op)) - return operandIndex == 2; - return false; -} - static Region* getParentRegion(Value value) { if (auto blockArgument = dyn_cast(value)) return blockArgument.getParentRegion(); @@ -63,7 +54,7 @@ static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region for (OpOperand& operand : op->getOpOperands()) { Value value = operand.get(); if (isDefinedInsideRegion(value, region) || isConstantExternalValue(value) - || isExplicitHostOperand(op, operand.getOperandNumber())) + || isExplicitHostMemCopyOperand(op, operand.getOperandNumber())) continue; InFlightDiagnostic diagnostic = ownerOp->emitOpError() diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index bda321d..80e9517 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -618,10 +618,6 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali llvm_unreachable("Cannot reach here"); } -Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) { - return getOrCreateHostIndexConstant(state.constantFolder, anchor, value); -} - // ----------------------------------------------------------------------------- // Tensor packing helpers. // ----------------------------------------------------------------------------- @@ -681,7 +677,7 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in if (dim0Size == 1) return index; - Value dim0SizeValue = createIndexConstant(state, anchor, dim0Size); + Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size); return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); } @@ -731,7 +727,7 @@ std::optional extractPackedProducerSlice(MaterializerState& state, state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - Value firstOffset = createIndexConstant(state, materializedClass.op, rowOffset); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); } @@ -754,7 +750,7 @@ Value getPackedSliceForRunIndex(MaterializerState& state, size_t index, Location loc) { int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); - Value firstOffset = createIndexConstant(state, anchor, rowOffset); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, anchor, rowOffset); return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } @@ -939,7 +935,7 @@ Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, Arr auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); auto attr = DenseIntElementsAttr::get(type, elements); - return getOrCreateHostConstant(state.constantFolder, anchor, attr, type); + return getOrCreateConstant(state.constantFolder, anchor, attr, type); } bool allEqual(ArrayRef values) { @@ -1041,7 +1037,7 @@ Value createIndexedIndexValue( assert(!values.empty() && "expected at least one indexed value"); if (allEqual(values)) - return createIndexConstant(state, anchor, values.front()); + return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); if (std::optional pattern = getIndexedIndexPattern(values)) return createAffineIndexValue(state, *pattern, index, loc); @@ -1110,7 +1106,7 @@ Value createOriginalLaneValue(MaterializerState& state, Location loc) { assert(!peers.empty() && "expected at least one peer instance"); if (!materializedClass.isBatch) - return createIndexConstant(state, materializedClass.op, peers.front().laneStart); + return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart); auto batch = cast(materializedClass.op); auto laneArg = batch.getLaneArgument(); @@ -1465,9 +1461,9 @@ void appendScalarSend(MaterializerState& state, assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId); - Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId); - Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); } @@ -1485,9 +1481,9 @@ void appendScalarSendLoop(MaterializerState& state, state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value lowerBound = createIndexConstant(state, sourceClass.op, 0); - Value upperBound = createIndexConstant(state, sourceClass.op, static_cast(channelIds.size())); - Value step = createIndexConstant(state, sourceClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(channelIds.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); @@ -1514,9 +1510,9 @@ Value buildProjectedPackedPayload(MaterializerState& state, state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) .getResult(); - Value lowerBound = createIndexConstant(state, anchor, 0); - Value upperBound = createIndexConstant(state, anchor, descriptor.fragmentsPerLane); - Value step = createIndexConstant(state, anchor, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); + Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); @@ -1531,7 +1527,7 @@ Value buildProjectedPackedPayload(MaterializerState& state, Value fragmentIndex = loop.getInductionVar(); Value acc = body->getArgument(1); - Value fragmentsPerLane = createIndexConstant(state, anchor, descriptor.fragmentsPerLane); + Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); @@ -1562,13 +1558,14 @@ void appendProjectedScalarSendLoop(MaterializerState& state, state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); if (channelIds.size() == 1) { - Value channelId = createIndexConstant(state, sourceClass.op, channelIds.front()); - Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds.front()); - Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds.front()); - Value laneIndex = createIndexConstant(state, sourceClass.op, 0); + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelIds.front()); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreIds.front()); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreIds.front()); + Value laneIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value sendPayload; if (descriptor.fragmentsPerLane == 1) { - Value offset = createIndexConstant(state, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front()); + Value offset = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorSourceDim0Offsets.front()); sendPayload = createDim0ExtractSlice(state, loc, payload, offset, descriptor.fragmentType.getDimSize(0)); } else { @@ -1579,9 +1576,9 @@ void appendProjectedScalarSendLoop(MaterializerState& state, return; } - Value lowerBound = createIndexConstant(state, sourceClass.op, 0); - Value upperBound = createIndexConstant(state, sourceClass.op, static_cast(channelIds.size())); - Value step = createIndexConstant(state, sourceClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(channelIds.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); @@ -1645,9 +1642,9 @@ Value appendScalarReceive(MaterializerState& state, assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value channelIdValue = createIndexConstant(state, targetClass.op, channelId); - Value sourceCoreIdValue = createIndexConstant(state, targetClass.op, sourceCoreId); - Value targetCoreIdValue = createIndexConstant(state, targetClass.op, targetCoreId); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) .getOutput(); } @@ -2132,9 +2129,9 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& Value init = tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); - Value lowerBound = createIndexConstant(state, targetClass.op, 0); - Value upperBound = createIndexConstant(state, targetClass.op, static_cast(keys.size())); - Value step = createIndexConstant(state, targetClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); @@ -2198,9 +2195,9 @@ FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerSt SmallVector resultIndices {run.resultIndex}; - Value lowerBound = createIndexConstant(state, targetClass.op, 0); - Value upperBound = createIndexConstant(state, targetClass.op, static_cast(keys.size())); - Value step = createIndexConstant(state, targetClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); @@ -2262,9 +2259,10 @@ FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& if (outputOffsets.size() != run.channelIds.size()) return failure(); - Value lowerBound = createIndexConstant(state, targetClass.op, 0); - Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.channelIds.size())); - Value step = createIndexConstant(state, targetClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.channelIds.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); @@ -2343,9 +2341,9 @@ FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, slotRowOffsets.push_back(static_cast(slotKey->instance.laneStart) * plan.rowsPerLane); } - Value lowerBound = createIndexConstant(state, targetClass.op, 0); - Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.slots.size())); - Value step = createIndexConstant(state, targetClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.slots.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); @@ -2507,7 +2505,7 @@ FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); int64_t rowOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; - Value outputOffset = createIndexConstant(state, targetClass.op, rowOffset); + Value outputOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); result = insertFragmentIntoWholeBatch(state, fragment.fragment, result, outputOffset, loc); } @@ -3050,7 +3048,7 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta const ComputeInstance& instance = run.front().peers.front(); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - Value laneValue = createIndexConstant(state, targetClass.op, instance.laneStart); + Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, instance.laneStart); return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices); } @@ -3087,9 +3085,9 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta laneStarts.push_back(instance.laneStart); } - Value lowerBound = createIndexConstant(state, targetClass.op, 0); - Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.size())); - Value step = createIndexConstant(state, targetClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange(initValues)); @@ -3563,9 +3561,9 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) return failure(); - Value lowerBound = createIndexConstant(state, targetClass.op, 0); - Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.size())); - Value step = createIndexConstant(state, targetClass.op, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); @@ -3669,9 +3667,9 @@ Value createReceiveConcatLoop(MaterializerState& state, assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); assert(!channelIds.empty() && "expected at least one receive"); - Value lowerBound = createIndexConstant(state, anchor, 0); - Value upperBound = createIndexConstant(state, anchor, static_cast(channelIds.size())); - Value step = createIndexConstant(state, anchor, 1); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, static_cast(channelIds.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); state.rewriter.setInsertionPoint(insertionPoint); Value init = diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index 14eef8d..4864851 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "../Common.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -41,7 +42,7 @@ static OpFoldResult addConstantOffset(OpFoldResult baseOffset, int64_t extraOffs } auto value = cast(baseOffset); - auto cst = arith::ConstantIndexOp::create(rewriter, value.getLoc(), extraOffset); + auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), extraOffset); return arith::AddIOp::create(rewriter, value.getLoc(), value, cst).getResult(); } @@ -75,7 +76,7 @@ static SmallVector delinearizeIndexValue(Value linearIndex, Value remaining = linearIndex; for (auto [_dim, stride] : llvm::enumerate(strides)) { - auto cStride = arith::ConstantIndexOp::create(rewriter, linearIndex.getLoc(), stride); + auto cStride = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), stride); Value index = arith::DivUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); indices.push_back(index); remaining = arith::RemUIOp::create(rewriter, linearIndex.getLoc(), remaining, cStride); @@ -90,7 +91,7 @@ static OpFoldResult addDynamicOffset(OpFoldResult baseOffset, Value extraOffset, if (integerAttr.getInt() == 0) return extraOffset; - auto cst = arith::ConstantIndexOp::create(rewriter, extraOffset.getLoc(), integerAttr.getInt()); + auto cst = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), integerAttr.getInt()); return arith::AddIOp::create(rewriter, extraOffset.getLoc(), cst, extraOffset).getResult(); } @@ -195,9 +196,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, if (allowLoopRewrite && numSlices > 1 && srcOffset == 0 && dstOffset == 0 && sourceType.getRank() == static_cast(copyShape.size()) && dstType.getRank() == static_cast(copyShape.size())) { - auto c0 = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 0); - auto cUpper = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), numSlices); - auto cStep = arith::ConstantIndexOp::create(rewriter, copyOp.getLoc(), 1); + auto c0 = getOrCreateIndexConstant(rewriter, copyOp, 0); + auto cUpper = getOrCreateIndexConstant(rewriter, copyOp, numSlices); + auto cStep = getOrCreateIndexConstant(rewriter, copyOp, 1); auto loop = scf::ForOp::create(rewriter, copyOp.getLoc(), c0, cUpper, cStep, ValueRange {}); rewriter.setInsertionPointToStart(loop.getBody()); @@ -284,8 +285,8 @@ struct RewriteHostSubviewLoadPattern final : OpRewritePattern +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -22,16 +23,6 @@ namespace onnx_mlir { namespace { -static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { - if (isa(op)) - return operandIndex == 3; - if (isa(op)) - return operandIndex == 1; - if (isa(op)) - return operandIndex == 2; - return false; -} - template static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, @@ -51,7 +42,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, for (OpOperand& operand : op->getOpOperands()) { Value originalValue = operand.get(); - if (!isa(originalValue.getType()) || isExplicitHostOperand(op, operand.getOperandNumber())) + if (!isa(originalValue.getType()) || isExplicitHostMemCopyOperand(op, operand.getOperandNumber())) continue; auto resolvedAddress = resolveContiguousAddress(originalValue); @@ -113,8 +104,8 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, rewriter, op->getLoc(), originalType, - getOrCreateHostIndexConstant(constantFolder, op, 0), - getOrCreateHostIndexConstant(constantFolder, op, static_cast(resolvedAddress->byteOffset) ), + getOrCreateIndexConstant(constantFolder, op, 0), + getOrCreateIndexConstant(constantFolder, op, static_cast(resolvedAddress->byteOffset) ), deviceDst, getGlobalOp.getResult(), rewriter.getI32IntegerAttr(static_cast(totalBytes))) diff --git a/src/PIM/Pass/PimCodegen/VerificationPass.cpp b/src/PIM/Pass/PimCodegen/VerificationPass.cpp index 94dec33..316648d 100644 --- a/src/PIM/Pass/PimCodegen/VerificationPass.cpp +++ b/src/PIM/Pass/PimCodegen/VerificationPass.cpp @@ -6,6 +6,7 @@ #include "llvm/ADT/STLExtras.h" +#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" @@ -119,16 +120,6 @@ static bool isConstantGlobalView(Value value) { } } -static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { - if (isa(op)) - return operandIndex == 3; - if (isa(op)) - return operandIndex == 1; - if (isa(op)) - return operandIndex == 2; - return false; -} - static bool isCoreWeightBlockArgument(Value value) { auto blockArgument = dyn_cast(value); if (!blockArgument) @@ -361,7 +352,7 @@ private: continue; } - if (isExplicitHostOperand(&op, operandIndex)) { + if (isExplicitHostMemCopyOperand(&op, operandIndex)) { if (!isCodegenAddressableValue(operand, knowledge)) { diagnostics.report(&op, [&](Operation* illegalOp) { illegalOp->emitOpError() << "host operand #" << operandIndex