diff --git a/backend-simulators/pim/pimsim-nn b/backend-simulators/pim/pimsim-nn index 3e3442b..6d3b898 160000 --- a/backend-simulators/pim/pimsim-nn +++ b/backend-simulators/pim/pimsim-nn @@ -1 +1 @@ -Subproject commit 3e3442b66354282e600c5c45990af0e92aecf0f9 +Subproject commit 6d3b898e6b191c4446dfcc8c085ba1e50125e942 diff --git a/onnx-mlir b/onnx-mlir index 82018d7..eb54c2a 160000 --- a/onnx-mlir +++ b/onnx-mlir @@ -1 +1 @@ -Subproject commit 82018d7ce59c94bfbe9479b16538224969fa45a0 +Subproject commit eb54c2afc46d00c6b196d1f275b6bfee17e12f69 diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 46f609b..9151cc4 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -9,6 +9,7 @@ add_pim_library(OMPimCommon IR/LoopUtils.cpp IR/ShapeUtils.cpp IR/SubviewUtils.cpp + IR/TensorSliceUtils.cpp IR/WeightUtils.cpp Support/CheckedArithmetic.cpp Support/DebugDump.cpp diff --git a/src/PIM/Common/IR/AffineUtils.cpp b/src/PIM/Common/IR/AffineUtils.cpp index 7fd3d23..c9dc72d 100644 --- a/src/PIM/Common/IR/AffineUtils.cpp +++ b/src/PIM/Common/IR/AffineUtils.cpp @@ -69,6 +69,15 @@ Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor); } +Value affineAddConst(RewriterBase& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + if (offset == 0) + return value; + + AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); + return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor); +} + Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) { assert(constantAnchor && "expected a valid constant anchor"); assert(divisor > 0 && "expected a positive affine.mod divisor"); @@ -90,6 +99,34 @@ Value affineFloorDivConst( return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor); } +Value affineAddModConst( + RewriterBase& rewriter, Location loc, Value value, int64_t offset, int64_t divisor, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + assert(divisor > 0 && "expected a positive affine.mod divisor"); + if (divisor == 1) + return getOrCreateIndexConstant(rewriter, constantAnchor, 0); + + AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); + AffineExpr expr = d0; + if (offset != 0) + expr = expr + offset; + return createOrFoldAffineApply(rewriter, loc, expr % divisor, ValueRange {value}, constantAnchor); +} + +Value affineAddFloorDivConst( + RewriterBase& rewriter, Location loc, Value value, int64_t offset, int64_t divisor, Operation* constantAnchor) { + assert(constantAnchor && "expected a valid constant anchor"); + assert(divisor > 0 && "expected a positive affine.floor_div divisor"); + if (divisor == 1) + return offset == 0 ? value : affineAddConst(rewriter, loc, value, offset, constantAnchor); + + AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); + AffineExpr expr = d0; + if (offset != 0) + expr = expr + offset; + return createOrFoldAffineApply(rewriter, loc, expr.floorDiv(divisor), ValueRange {value}, constantAnchor); +} + FailureOr evaluateAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { if (auto constant = dyn_cast(expr)) return constant.getValue(); diff --git a/src/PIM/Common/IR/AffineUtils.hpp b/src/PIM/Common/IR/AffineUtils.hpp index ed2d585..d9d3f58 100644 --- a/src/PIM/Common/IR/AffineUtils.hpp +++ b/src/PIM/Common/IR/AffineUtils.hpp @@ -29,6 +29,12 @@ mlir::Value affineMulConst(mlir::RewriterBase& rewriter, int64_t multiplier, mlir::Operation* constantAnchor); +mlir::Value affineAddConst(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value value, + int64_t offset, + mlir::Operation* constantAnchor); + mlir::Value affineModConst(mlir::RewriterBase& rewriter, mlir::Location loc, mlir::Value value, @@ -41,6 +47,20 @@ mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter, int64_t divisor, mlir::Operation* constantAnchor); +mlir::Value affineAddModConst(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value value, + int64_t offset, + int64_t divisor, + mlir::Operation* constantAnchor); + +mlir::Value affineAddFloorDivConst(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value value, + int64_t offset, + int64_t divisor, + mlir::Operation* constantAnchor); + llvm::FailureOr evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef dims, llvm::ArrayRef symbols = {}); diff --git a/src/PIM/Common/IR/ShapeUtils.cpp b/src/PIM/Common/IR/ShapeUtils.cpp index 3ae9b0a..14da6a1 100644 --- a/src/PIM/Common/IR/ShapeUtils.cpp +++ b/src/PIM/Common/IR/ShapeUtils.cpp @@ -218,6 +218,14 @@ getTransposePermutationChecked(std::optional permAttr, int64_t return permutation; } +llvm::SmallVector getStaticIndexAttrs(mlir::Builder& builder, llvm::ArrayRef values) { + llvm::SmallVector attrs; + attrs.reserve(values.size()); + for (int64_t value : values) + attrs.push_back(builder.getIndexAttr(value)); + return attrs; +} + llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank) { return llvm::SmallVector(rank, rewriter.getIndexAttr(1)); } diff --git a/src/PIM/Common/IR/ShapeUtils.hpp b/src/PIM/Common/IR/ShapeUtils.hpp index b2d567e..50ea76c 100644 --- a/src/PIM/Common/IR/ShapeUtils.hpp +++ b/src/PIM/Common/IR/ShapeUtils.hpp @@ -101,6 +101,8 @@ llvm::SmallVector invertPermutation(mlir::ArrayRef permutation mlir::FailureOr> getTransposePermutationChecked(std::optional permAttr, int64_t rank); +llvm::SmallVector getStaticIndexAttrs(mlir::Builder& builder, llvm::ArrayRef values); + llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank); llvm::SmallVector getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank); diff --git a/src/PIM/Common/IR/TensorSliceUtils.cpp b/src/PIM/Common/IR/TensorSliceUtils.cpp new file mode 100644 index 0000000..447036e --- /dev/null +++ b/src/PIM/Common/IR/TensorSliceUtils.cpp @@ -0,0 +1,71 @@ +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +Value extractAxisSlice( + PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { + auto sourceType = cast(source.getType()); + SmallVector resultShape(sourceType.getShape()); + resultShape[axis] = size; + auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding()); + + SmallVector offsets = getZeroOffsets(rewriter, sourceType.getRank()); + SmallVector sizes = getStaticSizes(rewriter, sourceType.getShape()); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(size); + return tensor::ExtractSliceOp::create( + rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank())) + .getResult(); +} + +Value extractStaticSliceOrIdentity(RewriterBase& rewriter, + Location loc, + Value source, + RankedTensorType resultType, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto sourceType = cast(source.getType()); + size_t rank = static_cast(sourceType.getRank()); + + bool isIdentitySlice = + sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank + && strides.size() == rank; + if (isIdentitySlice) { + ArrayRef sourceShape = sourceType.getShape(); + for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) { + std::optional staticOffset = mlir::getConstantIntValue(offset); + std::optional staticSize = mlir::getConstantIntValue(size); + std::optional staticStride = mlir::getConstantIntValue(stride); + if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim + || *staticStride != 1) { + isIdentitySlice = false; + break; + } + } + } + + if (isIdentitySlice) + return source; + + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); +} + +Value insertStaticSlice( + PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef offsets) { + auto sourceType = cast(source.getType()); + return tensor::InsertSliceOp::create(rewriter, + loc, + source, + dest, + offsets, + getStaticSizes(rewriter, sourceType.getShape()), + getUnitStrides(rewriter, sourceType.getRank())) + .getResult(); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Common/IR/TensorSliceUtils.hpp b/src/PIM/Common/IR/TensorSliceUtils.hpp new file mode 100644 index 0000000..3bc31d1 --- /dev/null +++ b/src/PIM/Common/IR/TensorSliceUtils.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" + +namespace onnx_mlir { + +mlir::Value extractAxisSlice( + mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); + +mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter, + mlir::Location loc, + mlir::Value source, + mlir::RankedTensorType resultType, + llvm::ArrayRef offsets, + llvm::ArrayRef sizes, + llvm::ArrayRef strides); + +mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::Value source, + mlir::Value dest, + llvm::ArrayRef offsets); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp index 301d7c7..d88cf01 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp @@ -5,5 +5,6 @@ #include "MatrixProductLowering.hpp" #include "ShapeTilingUtils.hpp" #include "WeightMaterialization.hpp" +#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index 3f8d2c6..c76a629 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -77,65 +77,4 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, PatternRewriter& rewri return slicesPerCore; } -Value extractAxisSlice( - PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { - auto sourceType = cast(source.getType()); - SmallVector resultShape(sourceType.getShape()); - resultShape[axis] = size; - auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding()); - - SmallVector offsets = getZeroOffsets(rewriter, sourceType.getRank()); - SmallVector sizes = getStaticSizes(rewriter, sourceType.getShape()); - offsets[axis] = rewriter.getIndexAttr(offset); - sizes[axis] = rewriter.getIndexAttr(size); - return tensor::ExtractSliceOp::create( - rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank())) - .getResult(); -} - -Value extractStaticSliceOrIdentity(RewriterBase& rewriter, - Location loc, - Value source, - RankedTensorType resultType, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - auto sourceType = cast(source.getType()); - size_t rank = static_cast(sourceType.getRank()); - - bool isIdentitySlice = - sourceType == resultType && sourceType.hasStaticShape() && offsets.size() == rank && sizes.size() == rank - && strides.size() == rank; - if (isIdentitySlice) { - ArrayRef sourceShape = sourceType.getShape(); - for (auto [dim, offset, size, stride] : llvm::zip_equal(sourceShape, offsets, sizes, strides)) { - std::optional staticOffset = mlir::getConstantIntValue(offset); - std::optional staticSize = mlir::getConstantIntValue(size); - std::optional staticStride = mlir::getConstantIntValue(stride); - if (!staticOffset || !staticSize || !staticStride || *staticOffset != 0 || *staticSize != dim || *staticStride != 1) { - isIdentitySlice = false; - break; - } - } - } - - if (isIdentitySlice) - return source; - - return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, strides).getResult(); -} - -Value insertStaticSlice( - PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef offsets) { - auto sourceType = cast(source.getType()); - return tensor::InsertSliceOp::create(rewriter, - loc, - source, - dest, - offsets, - getStaticSizes(rewriter, sourceType.getShape()), - getUnitStrides(rewriter, sourceType.getRank())) - .getResult(); -} - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index d154654..4fb9021 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -28,21 +28,4 @@ llvm::SmallVector sliceVector(const mlir::Value& vectorToSlice, llvm::DenseMap> sliceVectorPerCrossbarPerCore( const mlir::Value& vectorToSlice, mlir::PatternRewriter& rewriter, mlir::Location loc); -mlir::Value extractAxisSlice( - mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size); - -mlir::Value extractStaticSliceOrIdentity(mlir::RewriterBase& rewriter, - mlir::Location loc, - mlir::Value source, - mlir::RankedTensorType resultType, - llvm::ArrayRef offsets, - llvm::ArrayRef sizes, - llvm::ArrayRef strides); - -mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter, - mlir::Location loc, - mlir::Value source, - mlir::Value dest, - llvm::ArrayRef offsets); - } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 7b7d992..e41583a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -1184,48 +1184,6 @@ static Value createZeroPaddedTensor(Value value, return padOp.getResult(); } -static Value affineAddConst( - PatternRewriter& rewriter, Location loc, Value value, int64_t offset, Operation* constantAnchor) { - if (offset == 0) - return value; - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createOrFoldAffineApply(rewriter, loc, d0 + offset, ValueRange {value}, constantAnchor); -} - -static Value affineMulConst( - PatternRewriter& rewriter, Location loc, Value value, int64_t factor, Operation* constantAnchor) { - if (factor == 1) - return value; - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createOrFoldAffineApply(rewriter, loc, d0 * factor, ValueRange {value}, constantAnchor); -} - -static Value affineFloorDivConst( - PatternRewriter& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) { - assert(divisor > 0 && "expected positive affine floordiv divisor"); - if (divisor == 1) - return value; - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor); -} - -static Value affineModConst( - PatternRewriter& rewriter, Location loc, Value value, int64_t modulus, Operation* constantAnchor) { - assert(modulus > 0 && "expected positive affine mod divisor"); - if (modulus == 1) - return getOrCreateIndexConstant(rewriter, constantAnchor, 0); - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createOrFoldAffineApply(rewriter, loc, d0 % modulus, ValueRange {value}, constantAnchor); -} - static Value createConvInputPatch(Value input, RankedTensorType patchType, Value batchIndex, @@ -2316,11 +2274,10 @@ static Value createIm2colRows(const ConvLoweringState& state, ValueRange {im2colInit}, [&](OpBuilder&, Location nestedLoc, Value patchIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value im2colAcc = iterArgs.front(); - Value globalPatchIndex = affineAddConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, anchorOp); Value batchIndex = - affineFloorDivConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp); + affineAddFloorDivConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp); Value batchPatchIndex = - affineModConst(rewriter, nestedLoc, globalPatchIndex, plan.numPatchesPerBatch, anchorOp); + affineAddModConst(rewriter, nestedLoc, patchIndex, plan.chunkStart, plan.numPatchesPerBatch, anchorOp); Value outHeightIndex = affineFloorDivConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp); Value outWidthIndex = affineModConst(rewriter, nestedLoc, batchPatchIndex, state.outWidth, anchorOp); Value inputHeightOffset = diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index 3901d40..68ede56 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -15,22 +15,6 @@ namespace raptor { } // namespace raptor -static SmallVector getStaticIndexAttrs(Builder& builder, ArrayRef values) { - SmallVector attrs; - attrs.reserve(values.size()); - for (int64_t value : values) - attrs.push_back(builder.getIndexAttr(value)); - return attrs; -} - -static SmallVector getUnitStrides(Builder& builder, int64_t rank) { - SmallVector strides; - strides.reserve(rank); - for (int64_t dim = 0; dim < rank; ++dim) - strides.push_back(builder.getIndexAttr(1)); - return strides; -} - struct LowerFragmentAssemblyBlueprintPattern : OpConversionPattern { using OpConversionPattern::OpConversionPattern; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index e3b6c71..8f1c117 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -33,9 +33,9 @@ #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/TensorSliceUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir;