Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -0,0 +1,18 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<removeLRN>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,295 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
||||
return arg && canPromoteInputBlockArgument(*arg);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(outputArg->getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto newLaneArg = newCompute.getLaneArgument();
|
||||
if (!newLaneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
||||
mapper.map(*laneArg, *newLaneArg);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,22 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<onnxToArithConstant>(ctx);
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
patterns.add<matMulAddToGemm>(ctx);
|
||||
patterns.add<removeFlattenSameShape>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value createTransposeInit(Value input,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(resultType.getRank());
|
||||
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
|
||||
if (!ShapedType::isDynamic(resultDim)) {
|
||||
sizes.push_back(rewriter.getIndexAttr(resultDim));
|
||||
continue;
|
||||
}
|
||||
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
|
||||
}
|
||||
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
|
||||
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
|
||||
SmallVector<int64_t> permutation;
|
||||
if (auto permAttr = transposeOp.getPermAttr()) {
|
||||
permutation.reserve(permAttr.size());
|
||||
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
|
||||
permutation.push_back(attr.getInt());
|
||||
return permutation;
|
||||
}
|
||||
|
||||
permutation.reserve(inputType.getRank());
|
||||
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
|
||||
ONNXTransposeOpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
|
||||
if (!inputType || !resultType)
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp);
|
||||
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc());
|
||||
Value transposed =
|
||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation)
|
||||
.getResult()[0];
|
||||
rewriter.replaceOp(transposeOp, transposed);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<TransposeToLinalgTranspose>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user