Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 13:23:31 +02:00
parent 82b44a6387
commit 832bd7f1f7
37 changed files with 285 additions and 153 deletions
+1
View File
@@ -18,6 +18,7 @@ add_pim_library(OMPimCommon
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
MLIRLinalgDialect
onnx
SpatialOps
PimOps
+3 -2
View File
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -131,8 +132,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
return false;
});
+1 -1
View File
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
flags.elideLargeElementsAttrs().enableDebugInfo(true,false);
moduleOp.print(os, flags);
os.flush();
file.close();
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
Patterns.cpp
CompileTime.cpp
ONNXToSpatialVerifier.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns/Pre.cpp
Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
@@ -22,6 +23,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp
@@ -33,6 +35,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
value = transposeOp.getInput();
continue;
}
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
return failure();
IRMapping localMapper;
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -171,6 +172,16 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
@@ -226,6 +237,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
@@ -14,10 +15,8 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -86,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isCompileTimeOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
@@ -117,6 +92,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -156,11 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -187,9 +165,14 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet transposePatterns(ctx);
populateTransposePatterns(transposePatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -205,6 +188,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
@@ -222,8 +206,6 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
@@ -1,19 +1,16 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedPrePatterns(patterns, ctx);
}
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
@@ -27,6 +24,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateWeightPromotionPatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -1,38 +1,39 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
}
} // namespace onnx_mlir
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -10,7 +10,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -4,7 +4,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
@@ -9,7 +10,7 @@
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +21,7 @@ namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
@@ -276,7 +277,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
} // namespace
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
@@ -1,6 +1,5 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
@@ -12,7 +11,7 @@ namespace {
} // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
@@ -6,7 +6,7 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -6,7 +6,7 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -0,0 +1,75 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static Value createTransposeInit(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
if (!ShapedType::isDynamic(resultDim)) {
sizes.push_back(rewriter.getIndexAttr(resultDim));
continue;
}
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
}
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
}
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
SmallVector<int64_t> permutation;
if (auto permAttr = transposeOp.getPermAttr()) {
permutation.reserve(permAttr.size());
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
permutation.push_back(attr.getInt());
return permutation;
}
permutation.reserve(inputType.getRank());
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
ONNXTransposeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
if (!inputType || !resultType)
return failure();
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp);
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc());
Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation)
.getResult()[0];
rewriter.replaceOp(transposeOp, transposed);
return success();
}
};
} // namespace
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<TransposeToLinalgTranspose>(ctx);
}
} // namespace onnx_mlir
@@ -1,18 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -3,15 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
Patterns.cpp
SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Common.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
ReturnPathNormalization.cpp
TensorPackingPatterns.cpp
Patterns/ChannelLowering.cpp
Patterns/GlobalTensorMaterialization.cpp
Patterns/TensorPacking.cpp
Patterns/Transpose.cpp
EXCLUDE_FROM_OM_LIBS
@@ -19,6 +21,7 @@ add_pim_library(OMSpatialToPim
SpatialToPimIncGen
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
MLIRSCFUtils
MLIRTransformUtils
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
@@ -24,7 +25,7 @@ static bool isChannelUseChainOp(Operation* op) {
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
linalg::TransposeOp,
pim::PimTransposeOp>(op);
}
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
}
@@ -0,0 +1,40 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
void populateGlobalTensorMaterializationPatternPhase(RewritePatternSet& patterns) {
populateGlobalTensorMaterializationPatterns(patterns);
}
void populateInitialTensorPackingPatterns(RewritePatternSet& patterns) {
populateTensorPackingPatterns(patterns);
}
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
void populateFinalTensorPackingPatterns(RewritePatternSet& patterns) {
populateTensorPackingPatterns(patterns);
}
void populateCommunicationPatterns(RewritePatternSet& patterns) {
populateChannelLoweringPatterns(patterns);
}
} // namespace onnx_mlir
@@ -8,6 +8,18 @@
namespace onnx_mlir {
void populateInitialPatterns(mlir::RewritePatternSet& patterns);
void populateGlobalTensorMaterializationPatternPhase(mlir::RewritePatternSet& patterns);
void populateInitialTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns);
void populateFinalTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void populateCommunicationPatterns(mlir::RewritePatternSet& patterns);
void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value extractPackedChunk(mlir::Value packedValue,
mlir::RankedTensorType chunkType,
@@ -20,7 +32,6 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,6 +1,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -16,7 +16,7 @@
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -1,4 +1,4 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -0,0 +1,38 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct LinalgTransposeToPim final : OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter& rewriter) const override {
SmallVector<Attribute> permutationAttrs;
permutationAttrs.reserve(transposeOp.getPermutation().size());
for (int64_t dim : transposeOp.getPermutation())
permutationAttrs.push_back(rewriter.getI64IntegerAttr(dim));
auto permutation = rewriter.getArrayAttr(permutationAttrs);
auto pimTranspose = pim::PimTransposeOp::create(rewriter,
transposeOp.getLoc(),
TypeRange {transposeOp->getResult(0).getType()},
transposeOp.getInput(),
permutation,
transposeOp.getInit());
rewriter.replaceOp(transposeOp, pimTranspose.getOutput());
return success();
}
};
} // namespace
void populateTransposeLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<LinalgTransposeToPim>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -40,7 +41,7 @@ static bool isReturnHelperChainOp(Operation* op) {
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
linalg::TransposeOp,
pim::PimTransposeOp>(op);
}
@@ -276,11 +277,10 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
for (auto [destIndex, sourceIndex] : llvm::enumerate(transposeOp.getPermutation())) {
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
@@ -9,12 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE
def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMM : Pat<
(SpatVMMOp:$srcOpRes $weight, $vector),
(PimVMMOp $weight, $vector,
@@ -27,10 +27,8 @@
#include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Conversion/SpatialToPim/Patterns.hpp"
#include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp"
#include "Pass/PIMPasses.h"
@@ -41,11 +39,6 @@ using namespace onnx_mlir;
using namespace pim;
namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
@@ -159,7 +152,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
spatial::SpatExtractRowsOp>();
RewritePatternSet initialPatterns(ctx);
populateWithGenerated(initialPatterns);
populateInitialPatterns(initialPatterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
signalPassFailure();
@@ -167,7 +160,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
}
RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
populateGlobalTensorMaterializationPatternPhase(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
@@ -197,7 +190,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
}
RewritePatternSet initialTensorPackingPatterns(ctx);
populateTensorPackingPatterns(initialTensorPackingPatterns);
populateInitialTensorPackingPatterns(initialTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
@@ -214,7 +207,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
}
RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns);
populateCoreBodyPatterns(coreBodyPatterns);
populateAffineToStdConversionPatterns(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
@@ -257,7 +250,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
eraseOpsToRemove();
RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns);
populateFinalTensorPackingPatterns(finalTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter);
@@ -277,7 +270,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns);
populateCommunicationPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
signalPassFailure();
@@ -60,7 +60,7 @@ std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
}
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
constexpr size_t kMaxOctTableBytes = 1ull << 30;
constexpr size_t kMaxOctTableBytes = 1ull << 35;
if (nodeCount == 0 || processorCount == 0)
return;
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))