From 819d8af0f7657129be2f253381c94fc115219443 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Fri, 29 May 2026 15:57:13 +0200 Subject: [PATCH] Refactor + ReduceMean batched --- src/PIM/Common/IR/ConstantUtils.cpp | 22 +- src/PIM/Common/IR/ConstantUtils.hpp | 22 +- .../Conversion/ONNXToSpatial/CMakeLists.txt | 2 + .../ONNXToSpatial/Common/AttributeUtils.cpp | 23 ++ .../ONNXToSpatial/Common/AttributeUtils.hpp | 18 ++ .../ONNXToSpatial/Common/Common.hpp | 2 + .../Common/ComputeRegionBuilder.hpp | 101 +++++++ .../ONNXToSpatial/Common/IndexingUtils.cpp | 104 +++++++ .../ONNXToSpatial/Common/IndexingUtils.hpp | 45 +++ .../ONNXToSpatial/Common/ShapeTilingUtils.cpp | 131 ++++++++- .../ONNXToSpatial/Common/ShapeTilingUtils.hpp | 48 ++++ .../ONNXToSpatial/Patterns/Math/Conv.cpp | 18 +- .../ONNXToSpatial/Patterns/Math/Gemm.cpp | 194 +++++-------- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 64 +---- .../Patterns/Math/ReduceMean.cpp | 270 +++++++++++++----- .../ONNXToSpatial/Patterns/NN/Pool.cpp | 40 +-- .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 45 +-- .../ONNXToSpatial/Patterns/Post.cpp | 208 ++++++-------- .../ONNXToSpatial/Patterns/Tensor/Gather.cpp | 35 +-- .../ONNXToSpatial/Patterns/Tensor/Reshape.cpp | 19 +- .../ONNXToSpatial/Patterns/Tensor/Split.cpp | 31 +- .../Patterns/Tensor/Transpose.cpp | 25 +- .../SpatialToPim/CoreLoweringPatterns.cpp | 8 +- .../SpatialToPim/ReturnPathNormalization.cpp | 6 +- .../SpatialToPim/SpatialToPimPass.cpp | 8 +- .../MaterializeMergeSchedule.cpp | 4 +- .../MaterializeHostConstantsPass.cpp | 4 +- 27 files changed, 929 insertions(+), 568 deletions(-) create mode 100644 src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.hpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp create mode 100644 src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp diff --git a/src/PIM/Common/IR/ConstantUtils.cpp b/src/PIM/Common/IR/ConstantUtils.cpp index 59357f1..31e2ea3 100644 --- a/src/PIM/Common/IR/ConstantUtils.cpp +++ b/src/PIM/Common/IR/ConstantUtils.cpp @@ -28,7 +28,7 @@ Block* getHostConstantBlock(Operation* anchorOp) { return anchorOp->getBlock(); } -Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, OperationFolder& folder) { +Value getOrCreateHostConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) { assert(anchorOp && "expected a valid anchor operation"); Block* hostBlock = getHostConstantBlock(anchorOp); for (Operation& op : *hostBlock) { @@ -42,7 +42,7 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, O return folder.getOrCreateConstant(hostBlock, arithDialect, value, type); } -Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, RewriterBase& rewriter) { +Value getOrCreateHostConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) { assert(anchorOp && "expected a valid anchor operation"); Block* hostBlock = getHostConstantBlock(anchorOp); for (Operation& op : *hostBlock) { @@ -57,28 +57,28 @@ Value getOrCreateHostConstant(Operation* anchorOp, Attribute value, Type type, R return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast(value)).getResult(); } -Value getOrCreateHostConstantLike(arith::ConstantOp constantOp, OperationFolder& folder) { - return getOrCreateHostConstant(constantOp.getOperation(), constantOp.getValue(), constantOp.getType(), folder); +Value getOrCreateHostConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) { + return getOrCreateHostConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType()); } -Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) { +Value getOrCreateHostIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) { Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), folder); + return getOrCreateHostConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType() ); } -Value getOrCreateHostIndexConstant(Operation* anchorOp, int64_t value, RewriterBase& rewriter) { +Value getOrCreateHostIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) { Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(anchorOp, builder.getIndexAttr(value), builder.getIndexType(), rewriter); + return getOrCreateHostConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType()); } Value getOrCreateHostI32Constant(Operation* anchorOp, int32_t value, OperationFolder& folder) { Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type(), folder); + return getOrCreateHostConstant(folder, anchorOp, builder.getI32IntegerAttr(value), builder.getI32Type() ); } Value getOrCreateHostI64Constant(Operation* anchorOp, int64_t value, OperationFolder& folder) { Builder builder(anchorOp->getContext()); - return getOrCreateHostConstant(anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type(), folder); + return getOrCreateHostConstant(folder, anchorOp, builder.getI64IntegerAttr(value), builder.getI64Type() ); } Value createAffineApplyOrFoldedConstant( @@ -95,7 +95,7 @@ Value createAffineApplyOrFoldedConstant( SmallVector foldedResults; if (succeeded(map.constantFold(operandConstants, foldedResults))) { if (auto constantResult = dyn_cast(foldedResults.front())) - return getOrCreateHostIndexConstant(anchorOp, constantResult.getInt(), rewriter); + return getOrCreateHostIndexConstant(rewriter, anchorOp, constantResult.getInt()); } return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult(); diff --git a/src/PIM/Common/IR/ConstantUtils.hpp b/src/PIM/Common/IR/ConstantUtils.hpp index dc03959..c241fc9 100644 --- a/src/PIM/Common/IR/ConstantUtils.hpp +++ b/src/PIM/Common/IR/ConstantUtils.hpp @@ -10,25 +10,25 @@ namespace onnx_mlir { mlir::Block* getHostConstantBlock(mlir::Operation* anchorOp); -mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, +mlir::Value getOrCreateHostConstant(mlir::OperationFolder& folder, + mlir::Operation* anchorOp, mlir::Attribute value, - mlir::Type type, - mlir::OperationFolder& folder); + mlir::Type type); -mlir::Value getOrCreateHostConstant(mlir::Operation* anchorOp, +mlir::Value getOrCreateHostConstant(mlir::RewriterBase& rewriter, + mlir::Operation* anchorOp, mlir::Attribute value, - mlir::Type type, - mlir::RewriterBase& rewriter); + mlir::Type type); -mlir::Value getOrCreateHostConstantLike(mlir::arith::ConstantOp constantOp, mlir::OperationFolder& folder); +mlir::Value getOrCreateHostConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp); -mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); +mlir::Value getOrCreateHostIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value); -mlir::Value getOrCreateHostIndexConstant(mlir::Operation* anchorOp, int64_t value, mlir::RewriterBase& rewriter); +mlir::Value getOrCreateHostIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value); -mlir::Value getOrCreateHostI32Constant(mlir::Operation* anchorOp, int32_t value, mlir::OperationFolder& folder); +mlir::Value getOrCreateHostI32Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int32_t value); -mlir::Value getOrCreateHostI64Constant(mlir::Operation* anchorOp, int64_t value, mlir::OperationFolder& folder); +mlir::Value getOrCreateHostI64Constant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value); mlir::Value createAffineApplyOrFoldedConstant(mlir::RewriterBase& rewriter, mlir::Location loc, diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index 0e6f192..0b7e8cc 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -25,7 +25,9 @@ add_pim_library(OMONNXToSpatial Patterns/Tensor/Split.cpp Patterns/Tensor/Transpose.cpp ONNXToSpatialPass.cpp + Common/AttributeUtils.cpp Common/ComputeRegionBuilder.cpp + Common/IndexingUtils.cpp Common/ShapeTilingUtils.cpp Common/WeightMaterialization.cpp diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.cpp new file mode 100644 index 0000000..63e7817 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.cpp @@ -0,0 +1,23 @@ +#include "AttributeUtils.hpp" + +#include "mlir/IR/BuiltinAttributes.h" + +using namespace mlir; + +namespace onnx_mlir { + +int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast(attr[index]).getInt(); } + +int64_t getOptionalI64Attr(std::optional attr, size_t index, int64_t defaultValue) { + return attr ? getI64Attr(*attr, index) : defaultValue; +} + +llvm::SmallVector getI64ArrayAttrValues(ArrayAttr attr) { + llvm::SmallVector values; + values.reserve(attr.size()); + for (Attribute value : attr) + values.push_back(cast(value).getInt()); + return values; +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.hpp new file mode 100644 index 0000000..b14fc37 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Common/AttributeUtils.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "mlir/IR/BuiltinAttributes.h" + +#include "llvm/ADT/SmallVector.h" + +#include +#include + +namespace onnx_mlir { + +int64_t getI64Attr(mlir::ArrayAttr attr, size_t index); + +int64_t getOptionalI64Attr(std::optional attr, size_t index, int64_t defaultValue); + +llvm::SmallVector getI64ArrayAttrValues(mlir::ArrayAttr attr); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp index 9099e29..35ef47e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/Common.hpp @@ -1,6 +1,8 @@ #pragma once +#include "AttributeUtils.hpp" #include "ComputeRegionBuilder.hpp" +#include "IndexingUtils.hpp" #include "ShapeTilingUtils.hpp" #include "WeightMaterialization.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp index 7355c54..e503484 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp @@ -7,9 +7,13 @@ #include #include +#include #include #include +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" namespace onnx_mlir { @@ -49,6 +53,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult:: template using InvokeWithValueRangeResultT = std::invoke_result_t; +struct SpatComputeBatchBodyArgs { + mlir::Value lane; + mlir::ValueRange weights; + mlir::ValueRange inputs; + mlir::ValueRange outputs; +}; + } // namespace detail template @@ -159,6 +170,96 @@ auto createSpatCompute(RewriterT& rewriter, } } +template +auto createSpatComputeBatch(RewriterT& rewriter, + mlir::Location loc, + mlir::TypeRange resultTypes, + int64_t laneCount, + mlir::ValueRange weights, + mlir::ValueRange inputs, + BodyFn&& body) { + if (laneCount <= 0 || laneCount > std::numeric_limits::max()) + return mlir::FailureOr(mlir::failure()); + + auto batchOp = spatial::SpatComputeBatch::create( + rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast(laneCount)), weights, inputs); + + mlir::SmallVector blockArgTypes {rewriter.getIndexType()}; + mlir::SmallVector blockArgLocs {loc}; + blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size()); + blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size()); + for (mlir::Value weight : weights) { + blockArgTypes.push_back(weight.getType()); + blockArgLocs.push_back(weight.getLoc()); + } + for (mlir::Value input : inputs) { + blockArgTypes.push_back(input.getType()); + blockArgLocs.push_back(input.getLoc()); + } + for (mlir::Type resultType : resultTypes) { + blockArgTypes.push_back(resultType); + blockArgLocs.push_back(loc); + } + + auto* block = + rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs); + rewriter.setInsertionPointToStart(block); + + detail::SpatComputeBatchBodyArgs args { + block->getArgument(0), + mlir::ValueRange(block->getArguments()).slice(1, weights.size()), + mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()), + mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size()) + }; + + using BodyResult = std::invoke_result_t; + if constexpr (std::is_same_v) { + std::forward(body)(args); + rewriter.setInsertionPointAfter(batchOp); + return mlir::FailureOr(batchOp); + } + else { + auto bodyResult = std::forward(body)(args); + if (mlir::failed(bodyResult)) { + rewriter.setInsertionPointAfter(batchOp); + rewriter.eraseOp(batchOp); + return mlir::FailureOr(mlir::failure()); + } + rewriter.setInsertionPointAfter(batchOp); + return mlir::FailureOr(batchOp); + } +} + +inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::Value source, + mlir::Value dest, + mlir::ArrayRef offsets, + mlir::ArrayRef sizes, + mlir::ArrayRef strides) { + auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); + rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides); +} + +template +mlir::Value materializeOrComputeUnary(mlir::Value input, + mlir::RankedTensorType resultType, + mlir::PatternRewriter& rewriter, + mlir::Location loc, + BodyFn&& build) { + auto&& buildFn = build; + if (isCompileTimeComputable(input)) + return buildFn(input); + + auto computeOp = + createSpatCompute<1>(rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) { + mlir::Value result = buildFn(computeInput); + spatial::SpatYieldOp::create(rewriter, loc, result); + }); + return computeOp.getResult(0); +} + mlir::Value sumTensors(mlir::ArrayRef tensors, mlir::ConversionPatternRewriter& rewriter); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp new file mode 100644 index 0000000..745ac5c --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.cpp @@ -0,0 +1,104 @@ +#include "IndexingUtils.hpp" + +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "llvm/ADT/APInt.h" + +#include + +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } + +FailureOr normalizeAxisChecked(int64_t axis, int64_t rank) { + int64_t normalizedAxis = normalizeAxis(axis, rank); + if (normalizedAxis < 0 || normalizedAxis >= rank) + return failure(); + return normalizedAxis; +} + +int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; } + +static SmallVector normalizeAxesImpl(std::optional axesAttr, int64_t rank) { + SmallVector normalizedAxes; + if (!axesAttr) { + normalizedAxes.reserve(rank); + for (int64_t axis = 0; axis < rank; ++axis) + normalizedAxes.push_back(axis); + } + else { + normalizedAxes.reserve(axesAttr->size()); + for (Attribute attr : *axesAttr) + normalizedAxes.push_back(normalizeAxis(cast(attr).getInt(), rank)); + llvm::sort(normalizedAxes); + normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end()); + } + return normalizedAxes; +} + +SmallVector normalizeAxes(ArrayAttr axesAttr, int64_t rank) { + return normalizeAxesImpl(std::optional(axesAttr), rank); +} + +SmallVector normalizeAxes(std::optional axesAttr, int64_t rank) { + return normalizeAxesImpl(axesAttr, rank); +} + +FailureOr> normalizeAxesChecked(std::optional axesAttr, int64_t rank) { + SmallVector normalizedAxes = normalizeAxesImpl(axesAttr, rank); + for (int64_t axis : normalizedAxes) + if (axis < 0 || axis >= rank) + return failure(); + return normalizedAxes; +} + +FailureOr> normalizeAxesChecked(ArrayAttr axesAttr, int64_t rank) { + return normalizeAxesChecked(std::optional(axesAttr), rank); +} + +Value createAffineApplyOrConstant(PatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { + AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); + Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); + return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp); +} + +Value multiplyIndexByConstant(PatternRewriter& rewriter, Operation* anchorOp, Value value, int64_t multiplier) { + if (multiplier == 0) + return getOrCreateHostIndexConstant(rewriter, anchorOp, 0); + if (multiplier == 1) + return value; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApplyOrConstant(rewriter, anchorOp->getLoc(), d0 * multiplier, ValueRange {value}); +} + +Value modIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { + if (divisor == 1) + return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApplyOrConstant(rewriter, loc, d0 % divisor, ValueRange {value}); +} + +Value floorDivIndexByConstant(PatternRewriter& rewriter, Location loc, Value value, int64_t divisor) { + if (divisor == 1) + return value; + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}); +} + +Value getOrMaterializeIndexValue(PatternRewriter& rewriter, Location loc, OpFoldResult value) { + if (auto attr = dyn_cast(value)) + return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), cast(attr).getInt()); + return cast(value); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp new file mode 100644 index 0000000..3c20806 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Common/IndexingUtils.hpp @@ -0,0 +1,45 @@ +#pragma once + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "mlir/Interfaces/FoldInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include + +namespace onnx_mlir { + +int64_t normalizeAxis(int64_t axis, int64_t rank); + +mlir::FailureOr normalizeAxisChecked(int64_t axis, int64_t rank); + +int64_t normalizeIndex(int64_t index, int64_t dimSize); + +llvm::SmallVector normalizeAxes(mlir::ArrayAttr axesAttr, int64_t rank); + +llvm::SmallVector normalizeAxes(std::optional axesAttr, int64_t rank); + +mlir::FailureOr> normalizeAxesChecked(mlir::ArrayAttr axesAttr, int64_t rank); + +mlir::FailureOr> normalizeAxesChecked(std::optional axesAttr, int64_t rank); + +mlir::Value createAffineApplyOrConstant(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::AffineExpr expr, + mlir::ValueRange operands); + +mlir::Value +multiplyIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Operation* anchorOp, mlir::Value value, int64_t multiplier); + +mlir::Value modIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); + +mlir::Value +floorDivIndexByConstant(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value value, int64_t divisor); + +mlir::Value getOrMaterializeIndexValue(mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::OpFoldResult value); + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp index b96541f..54c283b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.cpp @@ -6,20 +6,21 @@ #include "llvm/ADT/SmallVector.h" #include +#include #include "ShapeTilingUtils.hpp" +#include "IndexingUtils.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" using namespace mlir; namespace onnx_mlir { static Value getIndexValue(OpFoldResult result, ConversionPatternRewriter& rewriter, Location loc) { - if (auto attr = dyn_cast(result)) - return arith::ConstantIndexOp::create(rewriter, loc, cast(attr).getInt()).getResult(); - return cast(result); + return getOrMaterializeIndexValue(rewriter, loc, result); } static Value addIndexValues(Value lhs, Value rhs, ConversionPatternRewriter& rewriter, Location loc) { @@ -50,6 +51,84 @@ static Value multiplyIndexValue(Value value, OpFoldResult factor, ConversionPatt return arith::MulIOp::create(rewriter, loc, value, factorValue).getResult(); } +bool hasStaticPositiveShape(ArrayRef shape) { + return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); +} + +bool hasStaticPositiveShape(RankedTensorType type) { return type.hasStaticShape() && hasStaticPositiveShape(type.getShape()); } + +int64_t getStaticShapeElementCount(ArrayRef shape) { + return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies {}); +} + +int64_t getStaticShapeElementCount(RankedTensorType type) { return getStaticShapeElementCount(type.getShape()); } + +SmallVector permuteShape(ArrayRef shape, ArrayRef permutation) { + SmallVector permutedShape; + permutedShape.reserve(permutation.size()); + for (int64_t axis : permutation) + permutedShape.push_back(shape[axis]); + return permutedShape; +} + +SmallVector invertPermutation(ArrayRef permutation) { + SmallVector inversePermutation(permutation.size()); + for (auto [newIndex, oldIndex] : llvm::enumerate(permutation)) + inversePermutation[oldIndex] = static_cast(newIndex); + return inversePermutation; +} + +FailureOr> getTransposePermutationChecked(std::optional permAttr, int64_t rank) { + SmallVector permutation; + if (!permAttr) { + permutation.reserve(rank); + for (int64_t dim = rank - 1; dim >= 0; --dim) + permutation.push_back(dim); + return permutation; + } + + if (static_cast(permAttr->size()) != rank) + return failure(); + + permutation.reserve(permAttr->size()); + SmallVector seen(rank, false); + for (IntegerAttr attr : permAttr->getAsRange()) { + int64_t axis = attr.getInt(); + if (axis < 0 || axis >= rank || seen[axis]) + return failure(); + seen[axis] = true; + permutation.push_back(axis); + } + return permutation; +} + +Value transposeMaybeInCompute(Value value, + RankedTensorType resultType, + ArrayRef permutation, + PatternRewriter& rewriter, + Location loc) { + auto buildTranspose = [&](Value input) -> Value { + return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult(); + }; + return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose); +} + +SmallVector getUnitStrides(PatternRewriter& rewriter, int64_t rank) { + return SmallVector(rank, rewriter.getIndexAttr(1)); +} + +SmallVector getZeroOffsets(PatternRewriter& rewriter, int64_t rank) { + return SmallVector(rank, rewriter.getIndexAttr(0)); +} + +SmallVector getStaticSizes(PatternRewriter& rewriter, ArrayRef shape) { + SmallVector sizes; + sizes.reserve(shape.size()); + for (int64_t dim : shape) + sizes.push_back(rewriter.getIndexAttr(dim)); + return sizes; +} + static bool isContiguousTensorSlice(Value source, RankedTensorType resultType, ArrayRef strides) { auto sourceType = dyn_cast(source.getType()); if (!sourceType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() || sourceType.getRank() != resultType.getRank()) @@ -88,11 +167,8 @@ SmallVector sliceTensor( assert("Invalid axis" && axis < shape.size()); SmallVector strides(shape.size(), rewriter.getIndexAttr(1)); - SmallVector offsets(shape.size(), rewriter.getIndexAttr(0)); - SmallVector sizes; - sizes.reserve(shape.size()); - for (const auto size : shape) - sizes.push_back(rewriter.getIndexAttr(size)); + SmallVector offsets = getZeroOffsets(rewriter, shape.size()); + SmallVector sizes = getStaticSizes(rewriter, shape); sizes[axis] = rewriter.getIndexAttr(sliceSize); long length = shape[axis]; @@ -276,4 +352,43 @@ Value materializeContiguousTensorSlice(Value source, return buildLoopNest(buildLoopNest, 0, init); } +Value extractStaticSlice(PatternRewriter& rewriter, + Location loc, + Value source, + RankedTensorType resultType, + ArrayRef offsets) { + return tensor::ExtractSliceOp::create( + rewriter, loc, resultType, source, offsets, getStaticSizes(rewriter, resultType.getShape()), + getUnitStrides(rewriter, resultType.getRank())) + .getResult(); +} + +Value extractAxisSlice( + PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) { + auto sourceType = cast(source.getType()); + SmallVector resultShape(sourceType.getShape()); + resultShape[axis] = size; + auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding()); + + SmallVector offsets = getZeroOffsets(rewriter, sourceType.getRank()); + SmallVector sizes = getStaticSizes(rewriter, sourceType.getShape()); + offsets[axis] = rewriter.getIndexAttr(offset); + sizes[axis] = rewriter.getIndexAttr(size); + return tensor::ExtractSliceOp::create(rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank())) + .getResult(); +} + +Value insertStaticSlice( + PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef offsets) { + auto sourceType = cast(source.getType()); + return tensor::InsertSliceOp::create(rewriter, + loc, + source, + dest, + offsets, + getStaticSizes(rewriter, sourceType.getShape()), + getUnitStrides(rewriter, sourceType.getRank())) + .getResult(); +} + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp index 908a9e7..3e90301 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/ShapeTilingUtils.hpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/ArrayRef.h" @@ -11,6 +12,7 @@ #include #include +#include #include #include @@ -109,6 +111,33 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) { && lhsType.getShape() == rhsType.getShape(); } +bool hasStaticPositiveShape(mlir::ArrayRef shape); + +bool hasStaticPositiveShape(mlir::RankedTensorType type); + +int64_t getStaticShapeElementCount(mlir::ArrayRef shape); + +int64_t getStaticShapeElementCount(mlir::RankedTensorType type); + +llvm::SmallVector permuteShape(mlir::ArrayRef shape, mlir::ArrayRef permutation); + +llvm::SmallVector invertPermutation(mlir::ArrayRef permutation); + +mlir::FailureOr> getTransposePermutationChecked(std::optional permAttr, + int64_t rank); + +mlir::Value transposeMaybeInCompute(mlir::Value value, + mlir::RankedTensorType resultType, + mlir::ArrayRef permutation, + mlir::PatternRewriter& rewriter, + mlir::Location loc); + +llvm::SmallVector getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank); + +llvm::SmallVector getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank); + +llvm::SmallVector getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef shape); + /// Slices a statically shaped tensor along one axis into contiguous pieces of /// at most `sliceSize` elements. llvm::SmallVector sliceTensor(const mlir::Value& tensorToSlice, @@ -148,4 +177,23 @@ mlir::Value materializeContiguousTensorSlice(mlir::Value source, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); +mlir::Value extractStaticSlice(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::Value source, + mlir::RankedTensorType resultType, + llvm::ArrayRef offsets); + +mlir::Value extractAxisSlice(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::Value source, + int64_t axis, + int64_t offset, + int64_t size); + +mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter, + mlir::Location loc, + mlir::Value source, + mlir::Value dest, + llvm::ArrayRef offsets); + } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp index 8f3b16a..4a15427 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Conv.cpp @@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern { ConversionPatternRewriter& rewriter) const override; }; -static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast(arr[idx]).getInt(); } - static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) { auto biasType = cast(bias.getType()); if (biasType.getRank() != 1) @@ -615,10 +613,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, return failure(); } - const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1; - const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1; - const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1; - const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1; + const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1); + const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1); + const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1); + const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1); int64_t padHeightBegin = 0; int64_t padHeightEnd = 0; @@ -626,10 +624,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp, int64_t padWidthEnd = 0; if (padsAttr) { - padHeightBegin = getI64FromArrayAttr(*padsAttr, 0); - padWidthBegin = getI64FromArrayAttr(*padsAttr, 1); - padHeightEnd = getI64FromArrayAttr(*padsAttr, 2); - padWidthEnd = getI64FromArrayAttr(*padsAttr, 3); + padHeightBegin = getI64Attr(*padsAttr, 0); + padWidthBegin = getI64Attr(*padsAttr, 1); + padHeightEnd = getI64Attr(*padsAttr, 2); + padWidthEnd = getI64Attr(*padsAttr, 3); } else { // Compute padding from auto_pad attribute diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index bf15091..5c13ea4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -13,7 +13,7 @@ #include #include -#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp" @@ -58,47 +58,16 @@ static Value transposeForSpatial(Value value, ArrayRef permutation, ConversionPatternRewriter& rewriter, Location loc) { - if (isCompileTimeComputable(value)) - return ONNXTransposeOp::create(rewriter, loc, resultType, value, rewriter.getI64ArrayAttr(permutation)); - - auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, value, [&](Value input) { - Value transposed = ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); - return computeOp.getResult(0); -} - -static Value createIndexConstant(ConversionPatternRewriter& rewriter, int64_t value) { - Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - return getOrCreateHostIndexConstant(anchorOp, value, rewriter); -} - -static Value -createAffineApply(ConversionPatternRewriter& rewriter, Location loc, AffineExpr expr, ValueRange operands) { - AffineMap map = AffineMap::get(/*dimCount=*/operands.size(), /*symbolCount=*/0, expr); - Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp(); - return createAffineApplyOrFoldedConstant(rewriter, loc, map, operands, anchorOp); + return transposeMaybeInCompute(value, resultType, permutation, rewriter, loc); } static Value multiplyIndexByConstant(Value value, int64_t multiplier, ConversionPatternRewriter& rewriter, Location loc) { - if (multiplier == 0) - return createIndexConstant(rewriter, 0); - if (multiplier == 1) - return value; - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}); + return onnx_mlir::multiplyIndexByConstant(rewriter, value.getDefiningOp(), value, multiplier); } static Value modIndexByConstant(Value value, int64_t divisor, ConversionPatternRewriter& rewriter, Location loc) { - if (divisor == 1) - return createIndexConstant(rewriter, 0); - - MLIRContext* context = rewriter.getContext(); - AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}); + return onnx_mlir::modIndexByConstant(rewriter, loc, value, divisor); } static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatternRewriter& rewriter, Location loc) { @@ -108,11 +77,11 @@ static Value createGemmBatchRow(Value lane, int64_t numOutRows, ConversionPatter static Value createGemmBatchKOffset( Value lane, int64_t numOutRows, int64_t numKSlices, ConversionPatternRewriter& rewriter, Location loc) { if (numKSlices == 1) - return createIndexConstant(rewriter, 0); + return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApply( + return createAffineApplyOrConstant( rewriter, loc, (d0.floorDiv(numOutRows) % numKSlices) * crossbarSize.getValue(), ValueRange {lane}); } @@ -123,11 +92,11 @@ static Value createGemmBatchHOffset(Value lane, ConversionPatternRewriter& rewriter, Location loc) { if (numOutHSlices == 1) - return createIndexConstant(rewriter, 0); + return getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApply( + return createAffineApplyOrConstant( rewriter, loc, d0.floorDiv(numOutRows * numKSlices) * crossbarSize.getValue(), ValueRange {lane}); } @@ -303,53 +272,37 @@ static spatial::SpatComputeBatch createVmmBatch(Value a, ConversionPatternRewriter& rewriter, Location loc) { const int64_t laneCount = partialPiecesType.getDimSize(0); - auto batchOp = spatial::SpatComputeBatch::create(rewriter, - loc, - TypeRange {partialPiecesType}, - rewriter.getI32IntegerAttr(static_cast(laneCount)), - ValueRange {b}, - ValueRange {a}); + auto batchOp = createSpatComputeBatch( + rewriter, loc, TypeRange {partialPiecesType}, laneCount, ValueRange {b}, ValueRange {a}, [&](detail::SpatComputeBatchBodyArgs args) { + Value row = createGemmBatchRow(args.lane, numOutRows, rewriter, loc); + Value kOffset = createGemmBatchKOffset(args.lane, numOutRows, numKSlices, rewriter, loc); + Value hOffset = createGemmBatchHOffset(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); - SmallVector blockArgTypes {rewriter.getIndexType(), paddedBType, aType, partialPiecesType}; - SmallVector blockArgLocs(blockArgTypes.size(), loc); - Block* body = - rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - rewriter.setInsertionPointToEnd(body); + auto aTileType = + RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType()); + auto bTileType = RankedTensorType::get( + {static_cast(crossbarSize.getValue()), static_cast(crossbarSize.getValue())}, + paddedBType.getElementType()); + auto pieceType = + RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType()); + Value aTile = extractATile(args.inputs.front(), row, kOffset, aTileType, rewriter, loc); - auto lane = batchOp.getLaneArgument(); - auto weight = batchOp.getWeightArgument(0); - auto input = batchOp.getInputArgument(0); - auto output = batchOp.getOutputArgument(0); - assert(lane && weight && input && output && "malformed Gemm compute_batch body"); + SmallVector bOffsets {kOffset, hOffset}; + SmallVector bSizes {rewriter.getIndexAttr(crossbarSize.getValue()), + rewriter.getIndexAttr(crossbarSize.getValue())}; + SmallVector unitStrides = getUnitStrides(rewriter, 2); + Value bTile = + tensor::ExtractSliceOp::create(rewriter, loc, bTileType, args.weights.front(), bOffsets, bSizes, unitStrides) + .getResult(); + Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); - Value row = createGemmBatchRow(*lane, numOutRows, rewriter, loc); - Value kOffset = createGemmBatchKOffset(*lane, numOutRows, numKSlices, rewriter, loc); - Value hOffset = createGemmBatchHOffset(*lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc); - - auto aTileType = RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, aType.getElementType()); - auto bTileType = RankedTensorType::get( - {static_cast(crossbarSize.getValue()), static_cast(crossbarSize.getValue())}, - paddedBType.getElementType()); - auto pieceType = - RankedTensorType::get({1, static_cast(crossbarSize.getValue())}, partialPiecesType.getElementType()); - Value aTile = extractATile(*input, row, kOffset, aTileType, rewriter, loc); - - SmallVector bOffsets {kOffset, hOffset}; - SmallVector bSizes {rewriter.getIndexAttr(crossbarSize.getValue()), - rewriter.getIndexAttr(crossbarSize.getValue())}; - SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value bTile = - tensor::ExtractSliceOp::create(rewriter, loc, bTileType, *weight, bOffsets, bSizes, unitStrides).getResult(); - Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult(); - - auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); - rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - SmallVector pieceOffsets {*lane, rewriter.getIndexAttr(0)}; - SmallVector pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())}; - tensor::ParallelInsertSliceOp::create(rewriter, loc, piece, *output, pieceOffsets, pieceSizes, unitStrides); - - rewriter.setInsertionPointAfter(batchOp); - return batchOp; + SmallVector pieceOffsets {args.lane, rewriter.getIndexAttr(0)}; + SmallVector pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())}; + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, unitStrides); + }); + assert(succeeded(batchOp) && "expected Gemm VMM batch construction to succeed"); + return *batchOp; } static Value createDynamicGemmBatchRow( @@ -359,7 +312,7 @@ static Value createDynamicGemmBatchRow( MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApply(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); + return createAffineApplyOrConstant(rewriter, loc, d0.floorDiv(numOutCols), ValueRange {lane}); } static Value createDynamicGemmBatchColumn( @@ -479,45 +432,27 @@ static spatial::SpatComputeBatch createVvdmulBatch(Value a, const int64_t numOutCols = outType.getDimSize(1); const int64_t reductionSize = aType.getDimSize(1); const int64_t laneCount = numOutRows * numOutCols; - auto batchOp = spatial::SpatComputeBatch::create(rewriter, - loc, - TypeRange {scalarPiecesType}, - rewriter.getI32IntegerAttr(static_cast(laneCount)), - ValueRange {}, - ValueRange {a, b}); + auto batchOp = createSpatComputeBatch( + rewriter, loc, TypeRange {scalarPiecesType}, laneCount, ValueRange {}, ValueRange {a, b}, [&](detail::SpatComputeBatchBodyArgs args) { + Value row = createDynamicGemmBatchRow(args.lane, numOutCols, rewriter, loc); + Value column = createDynamicGemmBatchColumn(args.lane, numOutCols, rewriter, loc); - SmallVector blockArgTypes {rewriter.getIndexType(), aType, bType, scalarPiecesType}; - SmallVector blockArgLocs(blockArgTypes.size(), loc); - Block* body = - rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - rewriter.setInsertionPointToEnd(body); + auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); + auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); + Value aVector = extractDynamicGemmRowVector(args.inputs[0], row, vectorType, rewriter, loc); + Value bVector = bAlreadyTransposed + ? extractTransposedBRow(args.inputs[1], column, vectorType, rewriter, loc) + : extractDynamicGemmBColumn(args.inputs[1], column, vectorType, rewriter, loc); + Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); - auto lane = batchOp.getLaneArgument(); - auto inputA = batchOp.getInputArgument(0); - auto inputB = batchOp.getInputArgument(1); - auto output = batchOp.getOutputArgument(0); - assert(lane && inputA && inputB && output && "malformed dynamic Gemm compute_batch body"); - - Value row = createDynamicGemmBatchRow(*lane, numOutCols, rewriter, loc); - Value column = createDynamicGemmBatchColumn(*lane, numOutCols, rewriter, loc); - - auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType()); - auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); - Value aVector = extractDynamicGemmRowVector(*inputA, row, vectorType, rewriter, loc); - Value bVector = bAlreadyTransposed - ? extractTransposedBRow(*inputB, column, vectorType, rewriter, loc) - : extractDynamicGemmBColumn(*inputB, column, vectorType, rewriter, loc); - Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult(); - - auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc); - rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - SmallVector outputOffsets {*lane, rewriter.getIndexAttr(0)}; - SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - SmallVector unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - tensor::ParallelInsertSliceOp::create(rewriter, loc, scalar, *output, outputOffsets, scalarSizes, unitStrides); - - rewriter.setInsertionPointAfter(batchOp); - return batchOp; + SmallVector outputOffsets {args.lane, rewriter.getIndexAttr(0)}; + SmallVector scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector unitStrides = getUnitStrides(rewriter, 2); + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, unitStrides); + }); + assert(succeeded(batchOp) && "expected Gemm VVDMul batch construction to succeed"); + return *batchOp; } static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, @@ -540,9 +475,9 @@ static spatial::SpatCompute createDynamicGemmOutputCompute(Value scalarPieces, Value biasArg = bias ? blockArgs[1] : Value(); auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType()); Value outputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult(); - Value c0 = createIndexConstant(rewriter, 0); - Value c1 = createIndexConstant(rewriter, 1); - Value cLaneCount = createIndexConstant(rewriter, laneCount); + Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cLaneCount = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount); auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit}); rewriter.setInsertionPointToStart(loop.getBody()); @@ -587,7 +522,8 @@ static Value createPartialGroupOffset(Value hSlice, Location loc) { MLIRContext* context = rewriter.getContext(); AffineExpr d0 = getAffineDimExpr(0, context); - return createAffineApply(rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); + return createAffineApplyOrConstant( + rewriter, loc, d0 * (numKSlices * numOutRows) + kSlice * numOutRows, ValueRange {hSlice}); } static Value extractReductionPiece(Value partialPiecesArg, @@ -684,13 +620,13 @@ static spatial::SpatCompute createReductionCompute(Value partialPieces, Value paddedOutput = outputInit; if (numOutHSlices == 1) { - Value hSlice = createIndexConstant(rewriter, 0); + Value hSlice = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); paddedOutput = buildOutputSlice(outputInit, hSlice); } else { - Value c0 = createIndexConstant(rewriter, 0); - Value c1 = createIndexConstant(rewriter, 1); - Value cOutHSlices = createIndexConstant(rewriter, numOutHSlices); + Value c0 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0); + Value c1 = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1); + Value cOutHSlices = getOrCreateHostIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices); auto hLoop = scf::ForOp::create(rewriter, loc, c0, cOutHSlices, c1, ValueRange {outputInit}); rewriter.setInsertionPointToStart(hLoop.getBody()); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 86ffded..eb7329e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -19,14 +19,6 @@ using namespace mlir; namespace onnx_mlir { namespace { -static bool haveStaticPositiveShape(ArrayRef shape) { - return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); -} - -static int64_t getStaticShapeElementCount(ArrayRef shape) { - return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies {}); -} - static FailureOr> inferSupportedBatchShape(ArrayRef lhsBatchShape, ArrayRef rhsBatchShape) { if (lhsBatchShape.empty()) @@ -54,15 +46,7 @@ collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, Pa auto buildCollapsed = [&](Value input) -> Value { return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation); }; - - if (isCompileTimeComputable(value)) - return buildCollapsed(value); - - auto collapseCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) { - spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input)); - }); - return collapseCompute.getResult(0); + return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed); } static Value @@ -76,12 +60,10 @@ expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, Patt for (size_t dim = 0; dim < batchRank; ++dim) reassociation.front().push_back(static_cast(dim)); - auto expandCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) { - Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation); - spatial::SpatYieldOp::create(rewriter, loc, expanded); - }); - return expandCompute.getResult(0); + auto buildExpanded = [&](Value input) -> Value { + return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult(); + }; + return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded); } static Value extractBatchMatrix(Value value, @@ -100,7 +82,7 @@ static Value extractBatchMatrix(Value value, rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)}; SmallVector sizes = { rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + SmallVector strides = getUnitStrides(rewriter, 3); auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType()); auto buildMatrix = [&](Value input) -> Value { Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides); @@ -114,14 +96,7 @@ static Value extractBatchMatrix(Value value, }); }; - if (isCompileTimeComputable(value)) - return buildMatrix(value); - - auto batchMatrixCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) { - spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input)); - }); - return batchMatrixCompute.getResult(0); + return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix); } static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) { @@ -138,18 +113,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati perm = {0, 2, 1}; } - auto buildTranspose = [&](Value input) -> Value { - return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); - }; - - if (isCompileTimeComputable(value)) - return buildTranspose(value); - - auto transposeCompute = - createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) { - spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input)); - }); - return transposeCompute.getResult(0); + return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc); } static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) { @@ -166,10 +130,11 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite perm = {0, 2, 1}; } - auto transposeCompute = createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { - Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); - spatial::SpatYieldOp::create(rewriter, loc, transposed); - }); + auto transposeCompute = + createSpatCompute<1>(rewriter, loc, transposedType, {}, ValueRange {value}, [&](Value input) { + Value transposed = ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm)); + spatial::SpatYieldOp::create(rewriter, loc, transposed); + }); return transposeCompute.getResult(0); } @@ -203,8 +168,7 @@ struct MatMulToGemm : OpRewritePattern { return failure(); if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2) return failure(); - if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape()) - || !haveStaticPositiveShape(outType.getShape())) + if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType)) return failure(); SmallVector lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index c89f06e..5181aa4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -1,9 +1,11 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" #include +#include #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" @@ -16,26 +18,6 @@ using namespace mlir; namespace onnx_mlir { namespace { -static SmallVector normalizeAxes(ArrayAttr axesAttr, int64_t rank) { - SmallVector normalizedAxes; - if (!axesAttr) { - normalizedAxes.reserve(rank); - for (int64_t axis = 0; axis < rank; axis++) - normalizedAxes.push_back(axis); - return normalizedAxes; - } - - normalizedAxes.reserve(axesAttr.size()); - for (Attribute attr : axesAttr) { - int64_t axis = cast(attr).getInt(); - normalizedAxes.push_back(axis >= 0 ? axis : rank + axis); - } - - llvm::sort(normalizedAxes); - normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end()); - return normalizedAxes; -} - static SmallVector buildReducedAxesMask(ArrayRef axes, int64_t rank) { SmallVector reducedAxes(rank, false); for (int64_t axis : axes) { @@ -50,6 +32,181 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT return RankedTensorType::get(SmallVector(inputType.getRank(), 1), elementType); } +static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef reducedAxes) { + SmallVector shape; + shape.reserve(inputType.getRank()); + for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes)) + shape.push_back(isReduced ? 1 : dim); + return RankedTensorType::get(shape, elementType, inputType.getEncoding()); +} + +static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef reducedAxes) { + SmallVector shape; + for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes)) + if (!isReduced) + shape.push_back(dim); + return RankedTensorType::get(shape, elementType, inputType.getEncoding()); +} + +static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef reducedAxes) { + SmallVector shape; + shape.reserve(inputType.getRank()); + for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes)) + shape.push_back(isReduced ? dim : 1); + return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding()); +} + +static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) { + SmallVector shape(leafType.getShape().begin(), leafType.getShape().end()); + shape.front() = laneCount; + return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding()); +} + +static SmallVector getKeptAxes(ArrayRef reducedAxes) { + SmallVector keptAxes; + for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) + if (!isReduced) + keptAxes.push_back(static_cast(axis)); + return keptAxes; +} + +static Value computeLaneIndex(Value lane, + int64_t stride, + int64_t dimSize, + ConversionPatternRewriter& rewriter, + Location loc) { + if (dimSize == 1) + return arith::ConstantIndexOp::create(rewriter, loc, 0); + + MLIRContext* context = rewriter.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr expr = d0; + if (stride != 1) + expr = expr.floorDiv(stride); + if (dimSize != 1) + expr = expr % dimSize; + return createAffineApplyOrConstant(rewriter, loc, expr, ValueRange {lane}); +} + +static FailureOr buildReduceMeanKeepdimsBatch(Value input, + ArrayRef reducedAxes, + RankedTensorType batchType, + RankedTensorType leafType, + ConversionPatternRewriter& rewriter, + Location loc) { + auto inputType = cast(input.getType()); + auto sliceType = getReducedSliceType(inputType, reducedAxes); + SmallVector keptAxes = getKeptAxes(reducedAxes); + + int64_t laneCount = 1; + SmallVector keptAxisStrides(keptAxes.size(), 1); + for (int64_t index = static_cast(keptAxes.size()) - 1; index >= 0; --index) { + keptAxisStrides[index] = laneCount; + int64_t dimSize = inputType.getDimSize(keptAxes[index]); + if (dimSize <= 0) + return failure(); + if (laneCount > std::numeric_limits::max() / dimSize) + return failure(); + laneCount *= dimSize; + } + + SmallVector sliceOffsets; + SmallVector sliceSizes; + SmallVector insertOffsets; + SmallVector insertSizes(inputType.getRank(), rewriter.getIndexAttr(1)); + SmallVector unitStrides = getUnitStrides(rewriter, inputType.getRank()); + sliceOffsets.reserve(inputType.getRank()); + sliceSizes.reserve(inputType.getRank()); + insertOffsets.reserve(inputType.getRank()); + + auto batchOp = createSpatComputeBatch( + rewriter, loc, TypeRange {batchType}, laneCount, {}, ValueRange {input}, [&](detail::SpatComputeBatchBodyArgs args) { + size_t keptAxisIndex = 0; + sliceOffsets.clear(); + sliceSizes.clear(); + insertOffsets.clear(); + for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) { + if (isReduced) { + sliceOffsets.push_back(rewriter.getIndexAttr(0)); + sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis))); + continue; + } + + Value axisIndex = + computeLaneIndex(args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc); + ++keptAxisIndex; + sliceOffsets.push_back(axisIndex); + sliceSizes.push_back(rewriter.getIndexAttr(1)); + } + + insertOffsets.push_back(args.lane); + insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0)); + + Value slice = + tensor::ExtractSliceOp::create(rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides); + Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult(); + createParallelInsertSliceIntoBatchOutput( + rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides); + }); + if (failed(batchOp)) + return failure(); + return (*batchOp).getResult(0); +} + +static Value buildKeepdimsFromLanePackedBatch(Value batchValue, + RankedTensorType keepdimsType, + RankedTensorType compactKeptType, + ArrayRef reducedAxes, + ConversionPatternRewriter& rewriter, + Location loc) { + auto batchType = cast(batchValue.getType()); + if (batchType == keepdimsType) + return batchValue; + + SmallVector collapseToFlat {{}}; + for (int64_t axis = 0; axis < batchType.getRank(); ++axis) + collapseToFlat.front().push_back(axis); + + SmallVector expandFlatToCompact(1); + for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis) + expandFlatToCompact.front().push_back(axis); + + SmallVector expandCompactToKeepdims; + ReassociationIndices pendingLeadingReducedAxes; + for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) { + if (isReduced) { + if (expandCompactToKeepdims.empty()) + pendingLeadingReducedAxes.push_back(axis); + else + expandCompactToKeepdims.back().push_back(axis); + continue; + } + + expandCompactToKeepdims.emplace_back(); + auto& group = expandCompactToKeepdims.back(); + group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end()); + pendingLeadingReducedAxes.clear(); + group.push_back(axis); + } + if (!pendingLeadingReducedAxes.empty()) + expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end()); + + auto reshapeCompute = + createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) { + auto flatType = RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding()); + Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat); + Value compact = flat; + if (compactKeptType != flatType) + compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact); + Value keepdims = compact; + if (keepdimsType != compactKeptType) + keepdims = + tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims); + spatial::SpatYieldOp::create(rewriter, loc, keepdims); + }); + return reshapeCompute.getResult(0); +} + static SmallVector buildCollapseReassociation(ArrayRef reducedAxes) { SmallVector reassociation; ReassociationIndices currentGroup; @@ -72,56 +229,6 @@ static SmallVector buildCollapseReassociation(ArrayRef(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) { - auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x); - spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult()); - }); - return computeOp.getResult(0); -} - -static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) { - auto firstType = cast(inputs.front().getType()); - SmallVector outputShape(firstType.getShape().begin(), firstType.getShape().end()); - int64_t concatDimSize = 0; - for (Value input : inputs) - concatDimSize += cast(input.getType()).getDimSize(axis); - outputShape[axis] = concatDimSize; - auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding()); - - if (llvm::all_of(inputs, isCompileTimeComputable)) - return createSpatConcat(rewriter, loc, axis, inputs); - - auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) { - spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args)); - }); - return concatCompute.getResult(0); -} - -static Value buildReduceMeanKeepdims(Value input, - ArrayRef reducedAxes, - int64_t axis, - RankedTensorType leafType, - ConversionPatternRewriter& rewriter, - Location loc) { - int64_t rank = cast(input.getType()).getRank(); - if (axis == rank) - return createAverageCompute(input, leafType, rewriter, loc); - - if (reducedAxes[axis]) - return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc); - - SmallVector slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc); - SmallVector reducedSlices; - reducedSlices.reserve(slices.size()); - for (Value slice : slices) - reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc)); - - return concatValues(reducedSlices, axis, rewriter, loc); -} - static Value squeezeReducedAxes(Value keepdimsValue, RankedTensorType resultType, ArrayRef reducedAxes, @@ -156,16 +263,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern { auto resultType = dyn_cast(reduceMeanOp.getReduced().getType()); if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); + if (inputType.getRank() == 0) { + rewriter.replaceOp(reduceMeanOp, adaptor.getData()); + return success(); + } - SmallVector axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank()); - SmallVector reducedAxes = buildReducedAxesMask(axes, inputType.getRank()); + auto axes = normalizeAxesChecked(reduceMeanOp.getAxesAttr(), inputType.getRank()); + if (failed(axes)) + return failure(); + SmallVector reducedAxes = buildReducedAxesMask(*axes, inputType.getRank()); if (reducedAxes.empty() && inputType.getRank() != 0) return failure(); Location loc = reduceMeanOp.getLoc(); RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType()); + RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes); + RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes); + int64_t laneCount = 1; + for (int64_t dim : compactKeptType.getShape()) + laneCount *= dim; + RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType); + + auto lanePackedKeepdims = + buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc); + if (failed(lanePackedKeepdims)) + return failure(); Value reducedKeepdims = - buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); + buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc); if (reduceMeanOp.getKeepdims() != 0) { rewriter.replaceOp(reduceMeanOp, reducedKeepdims); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp index a84e118..80ecda4 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Pool.cpp @@ -23,28 +23,10 @@ using namespace mlir; namespace onnx_mlir { namespace { -template -static int64_t getI64(ArrayAttrT arrayAttr, size_t index) { - return cast(arrayAttr[index]).getInt(); -} - -template -static int64_t getOptionalI64(std::optional arrayAttr, size_t index, int64_t defaultValue) { - return arrayAttr ? getI64(*arrayAttr, index) : defaultValue; -} - static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) { auto tileType = cast(tile.getType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); - - SmallVector offsets(tileType.getRank(), rewriter.getIndexAttr(0)); - SmallVector sizes; - sizes.reserve(tileType.getRank()); - for (int64_t dimSize : tileType.getShape()) - sizes.push_back(rewriter.getIndexAttr(dimSize)); - SmallVector strides(tileType.getRank(), rewriter.getIndexAttr(1)); - - return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides); + return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank())); } static Value @@ -197,12 +179,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { const int64_t inputWidth = xType.getDimSize(3); const int64_t outputHeight = outType.getDimSize(2); const int64_t outputWidth = outType.getDimSize(3); - const int64_t kernelHeight = getI64(kernelAttr, 0); - const int64_t kernelWidth = getI64(kernelAttr, 1); - const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1); - const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1); - const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1); - const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1); + const int64_t kernelHeight = getI64Attr(kernelAttr, 0); + const int64_t kernelWidth = getI64Attr(kernelAttr, 1); + const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1); + const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1); + const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1); + const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1); int64_t padTop = 0; int64_t padLeft = 0; @@ -212,10 +194,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern { if (auto padsAttr = poolOp.getPads()) { if (padsAttr->size() != 4) return rewriter.notifyMatchFailure(poolOp, "pads must have four elements."); - padTop = getI64(*padsAttr, 0); - padLeft = getI64(*padsAttr, 1); - padBottom = getI64(*padsAttr, 2); - padRight = getI64(*padsAttr, 3); + padTop = getI64Attr(*padsAttr, 0); + padLeft = getI64Attr(*padsAttr, 1); + padBottom = getI64Attr(*padsAttr, 2); + padRight = getI64Attr(*padsAttr, 3); } else { StringRef autoPad = poolOp.getAutoPad(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index 9ebdae8..bc95983 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -13,16 +13,6 @@ using namespace mlir; namespace onnx_mlir { namespace { -static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } - -static SmallVector permuteShape(ArrayRef shape, ArrayRef permutation) { - SmallVector permutedShape; - permutedShape.reserve(permutation.size()); - for (int64_t axis : permutation) - permutedShape.push_back(shape[axis]); - return permutedShape; -} - static Value buildLoopSoftmaxSlice(Value input, Value accumulator, RankedTensorType inputType, @@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input, SmallVector offsets; SmallVector sizes; - SmallVector strides(rank, rewriter.getIndexAttr(1)); + SmallVector strides = getUnitStrides(rewriter, rank); offsets.reserve(rank); sizes.reserve(rank); @@ -110,44 +100,31 @@ struct SoftmaxToSpatialCompute : OpConversionPattern { if (!inputType || !inputType.hasStaticShape()) return failure(); - int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank()); - if (axis < 0 || axis >= inputType.getRank()) + auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank()); + if (failed(axis)) return failure(); Value input = adaptor.getInput(); Value result; - if (axis == inputType.getRank() - 1) { + if (*axis == inputType.getRank() - 1) { result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc()); } else { SmallVector permutation; permutation.reserve(inputType.getRank()); for (int64_t dim = 0; dim < inputType.getRank(); ++dim) - if (dim != axis) + if (dim != *axis) permutation.push_back(dim); - permutation.push_back(axis); - - SmallVector inversePermutation(inputType.getRank()); - for (auto [newIndex, oldIndex] : llvm::enumerate(permutation)) - inversePermutation[oldIndex] = static_cast(newIndex); + permutation.push_back(*axis); + SmallVector inversePermutation = invertPermutation(permutation); auto transposedType = RankedTensorType::get( permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); - auto preTransposeCompute = - createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { - Value transposed = ONNXTransposeOp::create( - rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation)); - spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); - }); - Value transposedInput = preTransposeCompute.getResult(0); + Value transposedInput = + transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc()); Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc()); - auto postTransposeCompute = - createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) { - Value transposed = ONNXTransposeOp::create( - rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation)); - spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); - }); - result = postTransposeCompute.getResult(0); + result = transposeMaybeInCompute( + transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc()); } rewriter.replaceOp(softmaxOp, result); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp index ba556c3..d66524b 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp @@ -36,6 +36,14 @@ static bool isDirectConstantValue(Value value) { return isa_and_nonnull(value.getDefiningOp()); } +struct PromotedOperands { + SmallVector promoteInput; + SmallVector newWeights; + SmallVector newInputs; + SmallVector newInputTypes; + SmallVector newInputLocs; +}; + template static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { @@ -48,60 +56,91 @@ static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) { return false; } +template +static FailureOr computePromotedOperands(ComputeOpTy compute) { + PromotedOperands promoted; + promoted.promoteInput.assign(compute.getInputs().size(), false); + promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end()); + promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size()); + promoted.newInputs.reserve(compute.getInputs().size()); + promoted.newInputTypes.reserve(compute.getInputs().size()); + promoted.newInputLocs.reserve(compute.getInputs().size()); + + bool needsRewrite = false; + for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { + if (!isWeightLikeComputeOperand(input)) + goto keep_input; + if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx))) + goto keep_input; + promoted.promoteInput[inputIdx] = true; + promoted.newWeights.push_back(input); + needsRewrite = true; + continue; + + keep_input: + promoted.newInputs.push_back(input); + promoted.newInputTypes.push_back(input.getType()); + promoted.newInputLocs.push_back(input.getLoc()); + } + + if (!needsRewrite) + return failure(); + return promoted; +} + +template +static LogicalResult mapPromotedInputArguments(ComputeOpTy compute, + const PromotedOperands& promoted, + IRRewriter& bodyRewriter, + IRMapping& mapper, + std::function(size_t)> getNewInputArg, + PatternRewriter& rewriter) { + size_t newInputIdx = 0; + for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) { + auto oldArg = compute.getInputArgument(oldInputIdx); + if (!oldArg) + return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite"); + if (!promoted.promoteInput[oldInputIdx]) { + auto newInputArg = getNewInputArg(newInputIdx++); + if (!newInputArg) + return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument"); + mapper.map(*oldArg, *newInputArg); + continue; + } + + auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper); + if (failed(clonedValue)) + return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand"); + mapper.map(*oldArg, *clonedValue); + } + return success(); +} + // Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs. struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override { - SmallVector promoteInput(compute.getInputs().size(), false); - bool needsRewrite = false; - Block& oldBlock = compute.getBody().front(); - for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (!isWeightLikeComputeOperand(input)) - continue; - if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx))) - continue; - promoteInput[inputIdx] = true; - needsRewrite = true; - } - if (!needsRewrite) + auto promoted = computePromotedOperands(compute); + if (failed(promoted)) return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote"); + Block& oldBlock = compute.getBody().front(); rewriter.setInsertionPointAfter(compute); - - SmallVector newWeights(compute.getWeights().begin(), compute.getWeights().end()); - SmallVector newInputs; - SmallVector newInputTypes; - SmallVector newInputLocs; - newWeights.reserve(compute.getWeights().size() + compute.getInputs().size()); - newInputs.reserve(compute.getInputs().size()); - newInputTypes.reserve(compute.getInputs().size()); - newInputLocs.reserve(compute.getInputs().size()); - - for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (promoteInput[inputIdx]) { - newWeights.push_back(input); - continue; - } - newInputs.push_back(input); - newInputTypes.push_back(input.getType()); - newInputLocs.push_back(input.getLoc()); - } - auto newCompute = - spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs); + spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs); SmallVector newBlockArgTypes; SmallVector newBlockArgLocs; - for (Value weight : newWeights) { + for (Value weight : promoted->newWeights) { newBlockArgTypes.push_back(weight.getType()); newBlockArgLocs.push_back(weight.getLoc()); } - llvm::append_range(newBlockArgTypes, newInputTypes); - llvm::append_range(newBlockArgLocs, newInputLocs); + llvm::append_range(newBlockArgTypes, promoted->newInputTypes); + llvm::append_range(newBlockArgLocs, promoted->newInputLocs); auto* newBlock = rewriter.createBlock( &newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs); newCompute.getProperties().setOperandSegmentSizes( - {static_cast(newWeights.size()), static_cast(newInputs.size())}); + {static_cast(promoted->newWeights.size()), static_cast(promoted->newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); IRRewriter bodyRewriter(rewriter.getContext()); @@ -115,24 +154,9 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override { - SmallVector promoteInput(compute.getInputs().size(), false); - bool needsRewrite = false; - Block& oldBlock = compute.getBody().front(); - for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (!isWeightLikeComputeOperand(input)) - continue; - if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx))) - continue; - promoteInput[inputIdx] = true; - needsRewrite = true; - } - if (!needsRewrite) + auto promoted = computePromotedOperands(compute); + if (failed(promoted)) return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote"); + Block& oldBlock = compute.getBody().front(); rewriter.setInsertionPointAfter(compute); - SmallVector newWeights(compute.getWeights().begin(), compute.getWeights().end()); - SmallVector newInputs; - SmallVector newInputTypes; - SmallVector newInputLocs; - newWeights.reserve(compute.getWeights().size() + compute.getInputs().size()); - newInputs.reserve(compute.getInputs().size()); - newInputTypes.reserve(compute.getInputs().size()); - newInputLocs.reserve(compute.getInputs().size()); - - for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) { - if (promoteInput[inputIdx]) { - newWeights.push_back(input); - continue; - } - newInputs.push_back(input); - newInputTypes.push_back(input.getType()); - newInputLocs.push_back(input.getLoc()); - } - auto newCompute = spatial::SpatComputeBatch::create(rewriter, compute.getLoc(), compute.getResultTypes(), rewriter.getI32IntegerAttr(static_cast(compute.getLaneCount())), - newWeights, - newInputs); + promoted->newWeights, + promoted->newInputs); auto laneArg = compute.getLaneArgument(); if (!laneArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument"); SmallVector newBlockArgTypes; SmallVector newBlockArgLocs; - newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults()); - newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults()); + newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size() + compute.getNumResults()); + newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults()); newBlockArgTypes.push_back(laneArg->getType()); newBlockArgLocs.push_back(laneArg->getLoc()); - for (Value weight : newWeights) { + for (Value weight : promoted->newWeights) { newBlockArgTypes.push_back(weight.getType()); newBlockArgLocs.push_back(weight.getLoc()); } - llvm::append_range(newBlockArgTypes, newInputTypes); - llvm::append_range(newBlockArgLocs, newInputLocs); + llvm::append_range(newBlockArgTypes, promoted->newInputTypes); + llvm::append_range(newBlockArgLocs, promoted->newInputLocs); for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) { auto outputArg = compute.getOutputArgument(resultIndex); if (!outputArg) @@ -224,7 +220,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(newWeights.size()), static_cast(newInputs.size())}); + {static_cast(promoted->newWeights.size()), static_cast(promoted->newInputs.size())}); rewriter.setInsertionPointToStart(newBlock); IRRewriter bodyRewriter(rewriter.getContext()); @@ -242,29 +238,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(0, compute.getNumResults())) { auto outputArg = compute.getOutputArgument(resultIndex); if (!outputArg) return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite"); - mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex)); + mapper.map(*outputArg, + newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex)); } for (Operation& op : oldBlock) diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp index e388b83..79ae2c5 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp @@ -15,24 +15,6 @@ using namespace mlir; namespace onnx_mlir { namespace { -static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } - -static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; } - -static Value -extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) { - auto inputType = cast(input.getType()); - SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); - SmallVector sizes; - SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); - sizes.reserve(inputType.getRank()); - for (int64_t dim : inputType.getShape()) - sizes.push_back(rewriter.getIndexAttr(dim)); - offsets[axis] = rewriter.getIndexAttr(offset); - sizes[axis] = rewriter.getIndexAttr(1); - return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides); -} - static Value concatGatherSlices(Value data, int64_t axis, ArrayRef indices, @@ -45,7 +27,7 @@ static Value concatGatherSlices(Value data, int64_t normalizedIndex = normalizeIndex(index, axisDim); if (normalizedIndex < 0 || normalizedIndex >= axisDim) return {}; - slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc)); + slices.push_back(extractAxisSlice(rewriter, loc, data, axis, normalizedIndex, /*size=*/1)); } if (slices.empty()) return {}; @@ -96,11 +78,11 @@ struct Gather : OpConversionPattern { return failure(); int64_t rank = dataType.getRank(); - int64_t axis = normalizeAxis(gatherOp.getAxis(), rank); - if (axis < 0 || axis >= rank) + auto axis = normalizeAxisChecked(gatherOp.getAxis(), rank); + if (failed(axis)) return failure(); - int64_t axisDim = dataType.getShape()[axis]; + int64_t axisDim = dataType.getShape()[*axis]; if (axisDim <= 0) return failure(); @@ -116,7 +98,7 @@ struct Gather : OpConversionPattern { [&](Value data) -> LogicalResult { Value result; if (indicesType.getRank() == 1) { - result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); + result = concatGatherSlices(data, *axis, flatIndices, axisDim, rewriter, loc); } else if (indicesType.getRank() == 2) { int64_t rowCount = indicesType.getShape()[0]; @@ -125,12 +107,13 @@ struct Gather : OpConversionPattern { rows.reserve(rowCount); for (int64_t row = 0; row < rowCount; ++row) { ArrayRef rowIndices(flatIndices.data() + row * rowWidth, rowWidth); - Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc); + Value gatheredRow = + concatGatherSlices(data, *axis, rowIndices, axisDim, rewriter, loc); if (!gatheredRow) return failure(); - rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); + rows.push_back(addLeadingGatherDim(gatheredRow, *axis, rewriter, loc)); } - result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows); + result = createSpatConcat(rewriter, loc, /*axis=*/*axis, rows); } else { return failure(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index a766982..d75bc9f 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -14,10 +14,6 @@ using namespace mlir; namespace onnx_mlir { namespace { -static bool haveStaticPositiveShape(ArrayRef shape) { - return llvm::all_of(shape, [](int64_t dim) { return dim > 0; }); -} - static bool inferCollapseReassociation(ArrayRef sourceShape, ArrayRef resultShape, SmallVector& reassociation) { @@ -106,7 +102,7 @@ struct Reshape : OpConversionPattern { auto resultType = dyn_cast(reshapeOp.getReshaped().getType()); if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) return failure(); - if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape())) + if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType)) return failure(); if (sourceType == resultType) { @@ -115,17 +111,8 @@ struct Reshape : OpConversionPattern { } auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult { - if (isCompileTimeComputable(adaptor.getData())) { - rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData())); - return success(); - } - - auto computeOp = createSpatCompute<1>( - rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) { - Value reshaped = buildReshape(data); - spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped); - }); - rewriter.replaceOp(reshapeOp, computeOp.getResults()); + Value reshaped = materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape); + rewriter.replaceOp(reshapeOp, reshaped); return success(); }; diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp index ffdabfe..db82465 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -12,25 +12,6 @@ using namespace mlir; namespace onnx_mlir { namespace { -static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } - -static Value extractSliceAt( - Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) { - auto inputType = cast(input.getType()); - SmallVector offsets(inputType.getRank(), rewriter.getIndexAttr(0)); - SmallVector sizes; - SmallVector strides(inputType.getRank(), rewriter.getIndexAttr(1)); - sizes.reserve(inputType.getRank()); - for (int64_t dim : inputType.getShape()) - sizes.push_back(rewriter.getIndexAttr(dim)); - offsets[axis] = rewriter.getIndexAttr(offset); - sizes[axis] = rewriter.getIndexAttr(size); - SmallVector resultShape(inputType.getShape()); - resultShape[axis] = size; - auto resultType = RankedTensorType::get(resultShape, inputType.getElementType()); - return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides); -} - struct Split : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -41,8 +22,8 @@ struct Split : OpConversionPattern { return failure(); int64_t rank = inputType.getRank(); - int64_t axis = normalizeAxis(splitOp.getAxis(), rank); - if (axis < 0 || axis >= rank) + auto axis = normalizeAxisChecked(splitOp.getAxis(), rank); + if (failed(axis)) return failure(); SmallVector outputs; @@ -58,12 +39,13 @@ struct Split : OpConversionPattern { if (!resultType || !resultType.hasStaticShape()) return failure(); resultTypes.push_back(resultType); - sliceSizes.push_back(resultType.getShape()[axis]); + sliceSizes.push_back(resultType.getShape()[*axis]); } if (isCompileTimeComputable(adaptor.getInput())) { for (int64_t sliceSize : sliceSizes) { - outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); + outputs.push_back( + extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize)); offset += sliceSize; } rewriter.replaceOp(splitOp, outputs); @@ -76,7 +58,8 @@ struct Split : OpConversionPattern { runtimeOutputs.reserve(resultTypes.size()); int64_t runtimeOffset = 0; for (int64_t sliceSize : sliceSizes) { - runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc())); + runtimeOutputs.push_back( + extractAxisSlice(rewriter, splitOp.getLoc(), input, *axis, runtimeOffset, sliceSize)); runtimeOffset += sliceSize; } spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp index 07c524d..7085589 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp @@ -4,6 +4,7 @@ #include "llvm/ADT/SmallVector.h" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -29,22 +30,6 @@ static Value createTransposeInit(Value input, return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult(); } -static SmallVector getTransposePermutation(ONNXTransposeOp transposeOp) { - auto inputType = cast(transposeOp.getData().getType()); - SmallVector permutation; - if (auto permAttr = transposeOp.getPermAttr()) { - permutation.reserve(permAttr.size()); - for (IntegerAttr attr : permAttr.getAsRange()) - permutation.push_back(attr.getInt()); - return permutation; - } - - permutation.reserve(inputType.getRank()); - for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim) - permutation.push_back(dim); - return permutation; -} - struct TransposeToLinalgTranspose : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -56,10 +41,12 @@ struct TransposeToLinalgTranspose : OpConversionPattern { if (!inputType || !resultType) return failure(); - SmallVector permutation = getTransposePermutation(transposeOp); - Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc()); + auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank()); + if (failed(permutation)) + return failure(); + Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc()); Value transposed = - linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation) + linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation) .getResult()[0]; rewriter.replaceOp(transposeOp, transposed); return success(); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 2225b4c..a8c0fa5 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -40,7 +40,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite continue; if (auto constantOp = dyn_cast(definingOp)) { - mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder)); + mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp)); continue; } @@ -218,7 +218,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp continue; if (auto constantOp = input.getDefiningOp()) { - blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder)); + blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantFolder, constantOp)); continue; } @@ -230,8 +230,8 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp PimMemCopyHostToDevOp::create(rewriter, loc, outputBuffer.getType(), - getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder), - getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder), + getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0), + getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0), outputBuffer, input, getTensorSizeInBytesAttr(rewriter, input)) diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index a1376f5..18a5e1d 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -326,7 +326,7 @@ cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewrite continue; if (auto constantOp = dyn_cast(definingOp)) { - mapping.map(operand, getOrCreateHostConstantLike(constantOp, constantFolder)); + mapping.map(operand, getOrCreateHostConstantLike(constantFolder, constantOp)); continue; } @@ -370,8 +370,8 @@ static Value emitHostCopy(IRRewriter& rewriter, OperationFolder& constantFolder) { Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp(); assert(anchorOp && "expected a concrete op anchor for return-path host copy constants"); - Value hostTargetOffsetValue = getOrCreateHostIndexConstant(anchorOp, hostTargetOffset, constantFolder); - Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(anchorOp, deviceSourceOffset, constantFolder); + Value hostTargetOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, hostTargetOffset); + Value deviceSourceOffsetValue = getOrCreateHostIndexConstant(constantFolder, anchorOp, deviceSourceOffset); return PimMemCopyDevToHostOp::create(rewriter, loc, outputTensor.getType(), diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index b9de797..8022cc4 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -81,7 +81,7 @@ static Value createZeroedDeviceHVector(IRRewriter& rewriter, auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, tensorType); auto zeroGlobal = getOrCreateZeroGlobal(rewriter, loc, tensorType); auto zeroValue = memref::GetGlobalOp::create(rewriter, loc, zeroGlobal.getType(), zeroGlobal.getName()); - auto zeroIndex = getOrCreateHostIndexConstant(outputBuffer.getOperation(), 0, constantFolder); + auto zeroIndex = getOrCreateHostIndexConstant(constantFolder, outputBuffer.getOperation(), 0); auto sizeAttr = rewriter.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(tensorType))); if (outputBuffer->getParentOfType()) @@ -333,9 +333,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables( rewriter, loc, tensorType, - getOrCreateHostIndexConstant(deviceTensor.getOperation(), 0, constantFolder), - getOrCreateHostIndexConstant( - deviceTensor.getOperation(), static_cast(elementsOffset * elementByteSize), constantFolder), + getOrCreateHostIndexConstant(constantFolder, deviceTensor.getOperation(), 0), + getOrCreateHostIndexConstant(constantFolder, + deviceTensor.getOperation(), static_cast(elementsOffset * elementByteSize) ), deviceTensor, inputTensor, rewriter.getI32IntegerAttr(static_cast(tensorType.getNumElements() * elementByteSize))); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index ecda21c..bda321d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -619,7 +619,7 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali } Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) { - return getOrCreateHostIndexConstant(anchor, value, state.constantFolder); + return getOrCreateHostIndexConstant(state.constantFolder, anchor, value); } // ----------------------------------------------------------------------------- @@ -939,7 +939,7 @@ Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, Arr auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); auto attr = DenseIntElementsAttr::get(type, elements); - return getOrCreateHostConstant(anchor, attr, type, state.constantFolder); + return getOrCreateHostConstant(state.constantFolder, anchor, attr, type); } bool allEqual(ArrayRef values) { diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 490ca59..34decea 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -113,8 +113,8 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, rewriter, op->getLoc(), originalType, - getOrCreateHostIndexConstant(op, 0, constantFolder), - getOrCreateHostIndexConstant(op, static_cast(resolvedAddress->byteOffset), constantFolder), + getOrCreateHostIndexConstant(constantFolder, op, 0), + getOrCreateHostIndexConstant(constantFolder, op, static_cast(resolvedAddress->byteOffset) ), deviceDst, getGlobalOp.getResult(), rewriter.getI32IntegerAttr(static_cast(totalBytes)))