Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
ilgeco
2026-05-29 13:23:31 +02:00
parent 82b44a6387
commit 832bd7f1f7
37 changed files with 285 additions and 153 deletions
+1
View File
@@ -18,6 +18,7 @@ add_pim_library(OMPimCommon
${PIM_PUBLIC_INCLUDE_DIRS} ${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect
onnx onnx
SpatialOps SpatialOps
PimOps PimOps
+3 -2
View File
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@@ -131,8 +132,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self); return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user)) if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self); return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user)) if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self); return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
return false; return false;
}); });
+1 -1
View File
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file); llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags; mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs(); flags.elideLargeElementsAttrs().enableDebugInfo(true,false);
moduleOp.print(os, flags); moduleOp.print(os, flags);
os.flush(); os.flush();
file.close(); file.close();
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen) add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp Patterns.cpp
CompileTime.cpp CompileTime.cpp
ONNXToSpatialVerifier.cpp ONNXToSpatialVerifier.cpp
PrePatterns.cpp Patterns/Pre.cpp
PostPatterns.cpp Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp Patterns/Math/Gemm.cpp
@@ -22,6 +23,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Resize.cpp Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
Common/ComputeRegionBuilder.cpp Common/ComputeRegionBuilder.cpp
Common/ShapeTilingUtils.cpp Common/ShapeTilingUtils.cpp
@@ -33,6 +35,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen ONNXToSpatialIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect MLIRSCFDialect
MLIRTosaDialect MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
value = collapseShapeOp.getSrc(); value = collapseShapeOp.getSrc();
continue; continue;
} }
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) { if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
value = transposeOp.getData(); value = transposeOp.getInput();
continue; continue;
} }
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
return referencedValue.getResult(); return referencedValue.getResult();
} }
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp)) if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
return failure(); return failure();
IRMapping localMapper; IRMapping localMapper;
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
@@ -171,6 +172,16 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
return succeeded(transposedAttr) ? *transposedAttr : nullptr; return succeeded(transposedAttr) ? *transposedAttr : nullptr;
} }
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) { if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited); auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr) if (!inputAttr)
@@ -226,6 +237,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength); return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op)) if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength); return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
@@ -1,6 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
@@ -14,10 +15,8 @@
#include "Common/Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -86,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult); returnOp.setOperand(index, computeResult);
} }
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isCompileTimeOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() { void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
@@ -117,6 +92,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget preTarget(*ctx); ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect, preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect, tensor::TensorDialect,
affine::AffineDialect, affine::AffineDialect,
arith::ArithDialect, arith::ArithDialect,
@@ -156,11 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect, tensor::TensorDialect,
affine::AffineDialect, affine::AffineDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>(); target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>(); target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>(); target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>(); target.addIllegalOp<ONNXMulOp>();
@@ -187,9 +165,14 @@ void ONNXToSpatialPass::runOnOperation() {
return; return;
} }
RewritePatternSet transposePatterns(ctx);
populateTransposePatterns(transposePatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
ConversionTarget earlyPostTarget(*ctx); ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect, earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect, tensor::TensorDialect,
affine::AffineDialect, affine::AffineDialect,
arith::ArithDialect, arith::ArithDialect,
@@ -205,6 +188,7 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget postTarget(*ctx); ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect, postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect, tensor::TensorDialect,
affine::AffineDialect, affine::AffineDialect,
arith::ArithDialect, arith::ArithDialect,
@@ -222,8 +206,6 @@ void ONNXToSpatialPass::runOnOperation() {
return; return;
} }
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatial(*entryFunc))) { if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed"); moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure(); signalPassFailure();
@@ -1,19 +1,16 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedPrePatterns(patterns, ctx);
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" }
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx);
populateElementwisePatterns(patterns, ctx); populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx); populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx); populateConvPatterns(patterns, ctx);
@@ -27,6 +24,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
populateResizePatterns(patterns, ctx); populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx); populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateWeightPromotionPatterns(patterns, ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,38 +1,39 @@
#pragma once #pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,18 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
}
} // namespace onnx_mlir
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -10,7 +10,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -4,7 +4,7 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
@@ -9,7 +10,7 @@
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +21,7 @@ namespace onnx_mlir {
namespace { namespace {
static bool isWeightMaterializationHelperUser(Operation* op) { static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op); return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
} }
static bool canPromoteInputBlockArgument(BlockArgument arg) { static bool canPromoteInputBlockArgument(BlockArgument arg) {
@@ -276,7 +277,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
} // namespace } // namespace
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx); patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
} }
@@ -1,6 +1,5 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
using namespace mlir; using namespace mlir;
@@ -12,7 +11,7 @@ namespace {
} // namespace } // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) { void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx); patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx); patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx); patterns.add<convAddToConvWithBiasRight>(ctx);
@@ -6,7 +6,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -5,7 +5,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -6,7 +6,7 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -0,0 +1,75 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static Value createTransposeInit(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
if (!ShapedType::isDynamic(resultDim)) {
sizes.push_back(rewriter.getIndexAttr(resultDim));
continue;
}
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
}
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
}
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
SmallVector<int64_t> permutation;
if (auto permAttr = transposeOp.getPermAttr()) {
permutation.reserve(permAttr.size());
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
permutation.push_back(attr.getInt());
return permutation;
}
permutation.reserve(inputType.getRank());
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
ONNXTransposeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
if (!inputType || !resultType)
return failure();
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp);
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc());
Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation)
.getResult()[0];
rewriter.replaceOp(transposeOp, transposed);
return success();
}
};
} // namespace
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<TransposeToLinalgTranspose>(ctx);
}
} // namespace onnx_mlir
@@ -1,18 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -3,15 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen) add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim add_pim_library(OMSpatialToPim
Patterns.cpp
SpatialToPimPass.cpp SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Common.cpp Common.cpp
ComputeLikeRegionUtils.cpp ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
ReturnPathNormalization.cpp ReturnPathNormalization.cpp
TensorPackingPatterns.cpp Patterns/ChannelLowering.cpp
Patterns/GlobalTensorMaterialization.cpp
Patterns/TensorPacking.cpp
Patterns/Transpose.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -19,6 +21,7 @@ add_pim_library(OMSpatialToPim
SpatialToPimIncGen SpatialToPimIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect MLIRSCFDialect
MLIRSCFUtils MLIRSCFUtils
MLIRTransformUtils MLIRTransformUtils
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
@@ -24,7 +25,7 @@ static bool isChannelUseChainOp(Operation* op) {
tensor::ExpandShapeOp, tensor::ExpandShapeOp,
tensor::CastOp, tensor::CastOp,
tosa::ReshapeOp, tosa::ReshapeOp,
ONNXTransposeOp, linalg::TransposeOp,
pim::PimTransposeOp>(op); pim::PimTransposeOp>(op);
} }
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
}
@@ -0,0 +1,40 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
void populateGlobalTensorMaterializationPatternPhase(RewritePatternSet& patterns) {
populateGlobalTensorMaterializationPatterns(patterns);
}
void populateInitialTensorPackingPatterns(RewritePatternSet& patterns) {
populateTensorPackingPatterns(patterns);
}
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
void populateFinalTensorPackingPatterns(RewritePatternSet& patterns) {
populateTensorPackingPatterns(patterns);
}
void populateCommunicationPatterns(RewritePatternSet& patterns) {
populateChannelLoweringPatterns(patterns);
}
} // namespace onnx_mlir
@@ -8,6 +8,18 @@
namespace onnx_mlir { namespace onnx_mlir {
void populateInitialPatterns(mlir::RewritePatternSet& patterns);
void populateGlobalTensorMaterializationPatternPhase(mlir::RewritePatternSet& patterns);
void populateInitialTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns);
void populateFinalTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void populateCommunicationPatterns(mlir::RewritePatternSet& patterns);
void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count); mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value extractPackedChunk(mlir::Value packedValue, mlir::Value extractPackedChunk(mlir::Value packedValue,
mlir::RankedTensorType chunkType, mlir::RankedTensorType chunkType,
@@ -20,7 +32,6 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
mlir::OpBuilder& builder, mlir::OpBuilder& builder,
mlir::Location loc); mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc); mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,6 +1,6 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -16,7 +16,7 @@
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir; using namespace mlir;
@@ -1,4 +1,4 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -0,0 +1,38 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct LinalgTransposeToPim final : OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter& rewriter) const override {
SmallVector<Attribute> permutationAttrs;
permutationAttrs.reserve(transposeOp.getPermutation().size());
for (int64_t dim : transposeOp.getPermutation())
permutationAttrs.push_back(rewriter.getI64IntegerAttr(dim));
auto permutation = rewriter.getArrayAttr(permutationAttrs);
auto pimTranspose = pim::PimTransposeOp::create(rewriter,
transposeOp.getLoc(),
TypeRange {transposeOp->getResult(0).getType()},
transposeOp.getInput(),
permutation,
transposeOp.getInit());
rewriter.replaceOp(transposeOp, pimTranspose.getOutput());
return success();
}
};
} // namespace
void populateTransposeLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<LinalgTransposeToPim>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,5 +1,6 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
@@ -40,7 +41,7 @@ static bool isReturnHelperChainOp(Operation* op) {
tensor::ExpandShapeOp, tensor::ExpandShapeOp,
tensor::CastOp, tensor::CastOp,
tosa::ReshapeOp, tosa::ReshapeOp,
ONNXTransposeOp, linalg::TransposeOp,
pim::PimTransposeOp>(op); pim::PimTransposeOp>(op);
} }
@@ -276,11 +277,10 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
continue; continue;
} }
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) { if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size()); SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size()); SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) { for (auto [destIndex, sourceIndex] : llvm::enumerate(transposeOp.getPermutation())) {
int64_t sourceIndex = attr.getInt();
nextIndices[destIndex] = currentIndices[sourceIndex]; nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex]; nextShape[destIndex] = currentShape[sourceIndex];
} }
@@ -9,12 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
include "src/Accelerators/PIM/Dialect/Pim/Pim.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
#endif // OP_BASE #endif // OP_BASE
def onnxToPimTranspose : Pat<
(ONNXTransposeOp:$srcOpRes $data, $perms),
(PimTransposeOp $data, $perms,
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
>;
def spatToPimVMM : Pat< def spatToPimVMM : Pat<
(SpatVMMOp:$srcOpRes $weight, $vector), (SpatVMMOp:$srcOpRes $weight, $vector),
(PimVMMOp $weight, $vector, (PimVMMOp $weight, $vector,
@@ -27,10 +27,8 @@
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "Conversion/SpatialToPim/Common.hpp" #include "Conversion/SpatialToPim/Common.hpp"
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" #include "Conversion/SpatialToPim/Patterns.hpp"
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Spatial/SpatialOps.hpp" #include "Dialect/Spatial/SpatialOps.hpp"
#include "Pass/PIMPasses.h" #include "Pass/PIMPasses.h"
@@ -41,11 +39,6 @@ using namespace onnx_mlir;
using namespace pim; using namespace pim;
namespace onnx_mlir { namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>(); auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
@@ -159,7 +152,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
spatial::SpatExtractRowsOp>(); spatial::SpatExtractRowsOp>();
RewritePatternSet initialPatterns(ctx); RewritePatternSet initialPatterns(ctx);
populateWithGenerated(initialPatterns); populateInitialPatterns(initialPatterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form"); moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
signalPassFailure(); signalPassFailure();
@@ -167,7 +160,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
RewritePatternSet globalTensorPatterns(ctx); RewritePatternSet globalTensorPatterns(ctx);
populateGlobalTensorMaterializationPatterns(globalTensorPatterns); populateGlobalTensorMaterializationPatternPhase(globalTensorPatterns);
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
@@ -197,7 +190,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
RewritePatternSet initialTensorPackingPatterns(ctx); RewritePatternSet initialTensorPackingPatterns(ctx);
populateTensorPackingPatterns(initialTensorPackingPatterns); populateInitialTensorPackingPatterns(initialTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns)); walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter); eraseUnusedTensorPackingOps(funcOp, rewriter);
@@ -214,7 +207,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
} }
RewritePatternSet coreBodyPatterns(ctx); RewritePatternSet coreBodyPatterns(ctx);
populateWithGenerated(coreBodyPatterns); populateCoreBodyPatterns(coreBodyPatterns);
populateAffineToStdConversionPatterns(coreBodyPatterns); populateAffineToStdConversionPatterns(coreBodyPatterns);
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns)); FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
@@ -257,7 +250,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
eraseOpsToRemove(); eraseOpsToRemove();
RewritePatternSet finalTensorPackingPatterns(ctx); RewritePatternSet finalTensorPackingPatterns(ctx);
populateTensorPackingPatterns(finalTensorPackingPatterns); populateFinalTensorPackingPatterns(finalTensorPackingPatterns);
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns)); walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
eraseUnusedTensorPackingOps(funcOp, rewriter); eraseUnusedTensorPackingOps(funcOp, rewriter);
@@ -277,7 +270,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
spatial::SpatExtractRowsOp>(); spatial::SpatExtractRowsOp>();
RewritePatternSet communicationPatterns(ctx); RewritePatternSet communicationPatterns(ctx);
populateChannelLoweringPatterns(communicationPatterns); populateCommunicationPatterns(communicationPatterns);
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops"); funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
signalPassFailure(); signalPassFailure();
@@ -60,7 +60,7 @@ std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
} }
void verifyOctTableSize(size_t nodeCount, size_t processorCount) { void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
constexpr size_t kMaxOctTableBytes = 1ull << 30; constexpr size_t kMaxOctTableBytes = 1ull << 35;
if (nodeCount == 0 || processorCount == 0) if (nodeCount == 0 || processorCount == 0)
return; return;
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time)) if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))