From 832bd7f1f775921eae621fd59eaffe78617b5c19 Mon Sep 17 00:00:00 2001 From: ilgeco Date: Fri, 29 May 2026 13:23:31 +0200 Subject: [PATCH] Transpose and Refactor of Patterns --- src/PIM/Common/CMakeLists.txt | 1 + src/PIM/Common/IR/WeightUtils.cpp | 5 +- src/PIM/Common/Support/DebugDump.cpp | 2 +- .../Conversion/ONNXToSpatial/CMakeLists.txt | 9 ++- .../Common/WeightMaterialization.cpp | 7 +- .../Conversion/ONNXToSpatial/CompileTime.cpp | 14 ++++ .../ONNXToSpatial/ONNXToSpatialPass.cpp | 40 +++------- .../{ConversionPatterns.cpp => Patterns.cpp} | 20 ++--- .../{ConversionPatterns.hpp => Patterns.hpp} | 27 +++---- .../Patterns/GeneratedConversion.cpp | 18 +++++ .../Patterns/Math/Elementwise.cpp | 2 +- .../ONNXToSpatial/Patterns/Math/MatMul.cpp | 2 +- .../Patterns/Math/ReduceMean.cpp | 2 +- .../ONNXToSpatial/Patterns/NN/Softmax.cpp | 2 +- .../{PostPatterns.cpp => Patterns/Post.cpp} | 7 +- .../{PrePatterns.cpp => Patterns/Pre.cpp} | 5 +- .../ONNXToSpatial/Patterns/Tensor/Gather.cpp | 2 +- .../ONNXToSpatial/Patterns/Tensor/Reshape.cpp | 2 +- .../ONNXToSpatial/Patterns/Tensor/Resize.cpp | 2 +- .../ONNXToSpatial/Patterns/Tensor/Split.cpp | 2 +- .../Patterns/Tensor/Transpose.cpp | 75 +++++++++++++++++++ .../Conversion/ONNXToSpatial/PostPatterns.hpp | 18 ----- .../Conversion/ONNXToSpatial/PrePatterns.hpp | 10 --- .../Conversion/SpatialToPim/CMakeLists.txt | 9 ++- .../SpatialToPim/ChannelLoweringPatterns.hpp | 9 --- .../SpatialToPim/CoreLoweringPatterns.cpp | 3 +- .../GlobalTensorMaterialization.hpp | 9 --- src/PIM/Conversion/SpatialToPim/Patterns.cpp | 40 ++++++++++ ...TensorPackingPatterns.hpp => Patterns.hpp} | 13 +++- .../ChannelLowering.cpp} | 2 +- .../GlobalTensorMaterialization.cpp | 2 +- .../TensorPacking.cpp} | 2 +- .../SpatialToPim/Patterns/Transpose.cpp | 38 ++++++++++ .../SpatialToPim/ReturnPathNormalization.cpp | 8 +- .../Conversion/SpatialToPim/SpatialToPim.td | 6 -- .../SpatialToPim/SpatialToPimPass.cpp | 21 ++---- .../Scheduling/PeftScheduler.cpp | 2 +- 37 files changed, 285 insertions(+), 153 deletions(-) rename src/PIM/Conversion/ONNXToSpatial/{ConversionPatterns.cpp => Patterns.cpp} (57%) rename src/PIM/Conversion/ONNXToSpatial/{ConversionPatterns.hpp => Patterns.hpp} (63%) create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/GeneratedConversion.cpp rename src/PIM/Conversion/ONNXToSpatial/{PostPatterns.cpp => Patterns/Post.cpp} (98%) rename src/PIM/Conversion/ONNXToSpatial/{PrePatterns.cpp => Patterns/Pre.cpp} (66%) create mode 100644 src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp delete mode 100644 src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp delete mode 100644 src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp delete mode 100644 src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp delete mode 100644 src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp create mode 100644 src/PIM/Conversion/SpatialToPim/Patterns.cpp rename src/PIM/Conversion/SpatialToPim/{TensorPackingPatterns.hpp => Patterns.hpp} (64%) rename src/PIM/Conversion/SpatialToPim/{ChannelLoweringPatterns.cpp => Patterns/ChannelLowering.cpp} (97%) rename src/PIM/Conversion/SpatialToPim/{ => Patterns}/GlobalTensorMaterialization.cpp (99%) rename src/PIM/Conversion/SpatialToPim/{TensorPackingPatterns.cpp => Patterns/TensorPacking.cpp} (99%) create mode 100644 src/PIM/Conversion/SpatialToPim/Patterns/Transpose.cpp diff --git a/src/PIM/Common/CMakeLists.txt b/src/PIM/Common/CMakeLists.txt index 1bdc67f..effdd6d 100644 --- a/src/PIM/Common/CMakeLists.txt +++ b/src/PIM/Common/CMakeLists.txt @@ -18,6 +18,7 @@ add_pim_library(OMPimCommon ${PIM_PUBLIC_INCLUDE_DIRS} LINK_LIBS PUBLIC + MLIRLinalgDialect onnx SpatialOps PimOps diff --git a/src/PIM/Common/IR/WeightUtils.cpp b/src/PIM/Common/IR/WeightUtils.cpp index 4cd8168..1d3b7be 100644 --- a/src/PIM/Common/IR/WeightUtils.cpp +++ b/src/PIM/Common/IR/WeightUtils.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -131,8 +132,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self); if (auto collapseShapeOp = mlir::dyn_cast(user)) return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self); - if (auto transposeOp = mlir::dyn_cast(user)) - return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self); + if (auto transposeOp = mlir::dyn_cast(user)) + return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self); return false; }); diff --git a/src/PIM/Common/Support/DebugDump.cpp b/src/PIM/Common/Support/DebugDump.cpp index c6a3593..36c26df 100644 --- a/src/PIM/Common/Support/DebugDump.cpp +++ b/src/PIM/Common/Support/DebugDump.cpp @@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) { std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); llvm::raw_os_ostream os(file); mlir::OpPrintingFlags flags; - flags.elideLargeElementsAttrs(); + flags.elideLargeElementsAttrs().enableDebugInfo(true,false); moduleOp.print(os, flags); os.flush(); file.close(); diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt index d19c4cd..0e6f192 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt +++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt @@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(ONNXToSpatialIncGen) add_pim_library(OMONNXToSpatial - ConversionPatterns.cpp + Patterns.cpp CompileTime.cpp ONNXToSpatialVerifier.cpp - PrePatterns.cpp - PostPatterns.cpp + Patterns/Pre.cpp + Patterns/Post.cpp + Patterns/GeneratedConversion.cpp Patterns/Math/Conv.cpp Patterns/Math/Elementwise.cpp Patterns/Math/Gemm.cpp @@ -22,6 +23,7 @@ add_pim_library(OMONNXToSpatial Patterns/Tensor/Resize.cpp Patterns/Tensor/Reshape.cpp Patterns/Tensor/Split.cpp + Patterns/Tensor/Transpose.cpp ONNXToSpatialPass.cpp Common/ComputeRegionBuilder.cpp Common/ShapeTilingUtils.cpp @@ -33,6 +35,7 @@ add_pim_library(OMONNXToSpatial ONNXToSpatialIncGen LINK_LIBS PUBLIC + MLIRLinalgDialect MLIRSCFDialect MLIRTosaDialect OMCompilerOptions diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp index 75931bf..27e2f93 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" @@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) { value = collapseShapeOp.getSrc(); continue; } - if (auto transposeOp = dyn_cast(definingOp)) { - value = transposeOp.getData(); + if (auto transposeOp = dyn_cast(definingOp)) { + value = transposeOp.getInput(); continue; } @@ -80,7 +81,7 @@ FailureOr materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr return referencedValue.getResult(); } - if (!isa(definingOp)) + if (!isa(definingOp)) return failure(); IRMapping localMapper; diff --git a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp index 45ffe57..c947a2a 100644 --- a/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/CompileTime.cpp @@ -1,4 +1,5 @@ #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -171,6 +172,16 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm: return succeeded(transposedAttr) ? *transposedAttr : nullptr; } + if (auto transposeOp = dyn_cast(definingOp)) { + auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited); + if (!inputAttr) + return nullptr; + + SmallVector perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end()); + auto transposedAttr = transposeDenseElements(inputAttr, perm); + return succeeded(transposedAttr) ? *transposedAttr : nullptr; + } + if (auto collapseShapeOp = dyn_cast(definingOp)) { auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited); if (!inputAttr) @@ -226,6 +237,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl& visit if (auto transposeOp = dyn_cast(op)) return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength); + if (auto transposeOp = dyn_cast(op)) + return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength); + if (auto collapseShapeOp = dyn_cast(op)) return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength); diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp index 96f5a46..5869c21 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp @@ -1,6 +1,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -14,10 +15,8 @@ #include "Common/Common.hpp" #include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -86,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) { returnOp.setOperand(index, computeResult); } -static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) { - IRRewriter rewriter(funcOp.getContext()); - Block& entryBlock = funcOp.getFunctionBody().front(); - - for (Operation& op : llvm::make_early_inc_range(entryBlock)) { - auto transposeOp = dyn_cast(&op); - if (!transposeOp || isCompileTimeOp(transposeOp)) - continue; - - // Transpose stays globally legal because constant/view-only cases are - // allowed on the host. Any residual runtime transpose must be sunk into - // spat.compute before the host legality check. - auto resultType = transposeOp.getResult().getType(); - rewriter.setInsertionPoint(transposeOp); - auto computeOp = createSpatCompute<1>( - rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) { - Value transposed = - ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr()); - spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed); - }); - rewriter.replaceOp(transposeOp, computeOp.getResult(0)); - } -} - void ONNXToSpatialPass::runOnOperation() { ModuleOp moduleOp = getOperation(); MLIRContext* ctx = &getContext(); @@ -117,6 +92,7 @@ void ONNXToSpatialPass::runOnOperation() { ConversionTarget preTarget(*ctx); preTarget.addLegalDialect(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -187,9 +165,14 @@ void ONNXToSpatialPass::runOnOperation() { return; } + RewritePatternSet transposePatterns(ctx); + populateTransposePatterns(transposePatterns, ctx); + walkAndApplyPatterns(moduleOp, std::move(transposePatterns)); + ConversionTarget earlyPostTarget(*ctx); earlyPostTarget.addLegalDialect(ctx); +void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + populateGeneratedPrePatterns(patterns, ctx); +} +void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + populateGeneratedConversionPatterns(patterns, ctx); populateElementwisePatterns(patterns, ctx); populateGemmPatterns(patterns, ctx); populateConvPatterns(patterns, ctx); @@ -27,6 +24,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon populateResizePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx); populateSplitPatterns(patterns, ctx); + populateTransposePatterns(patterns, ctx); +} + +void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + populateWeightPromotionPatterns(patterns, ctx); } } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp similarity index 63% rename from src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns.hpp index 892c18d..e58729e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp @@ -1,38 +1,39 @@ #pragma once +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + namespace onnx_mlir { +void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); +void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); + +bool requiresPostRewrite(spatial::SpatCompute computeOp); +bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); +void annotateWeightsConstants(mlir::func::FuncOp funcOp); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/GeneratedConversion.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/GeneratedConversion.cpp new file mode 100644 index 0000000..5fc96a2 --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/GeneratedConversion.cpp @@ -0,0 +1,18 @@ +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" + +using namespace mlir; + +namespace onnx_mlir { + +namespace { + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc" + +} // namespace + +void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.add(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp index 44c2fc1..4615c76 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Elementwise.cpp @@ -7,7 +7,7 @@ #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp index 9859ad2..86ffded 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/MatMul.cpp @@ -10,7 +10,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp index 1218481..c89f06e 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/ReduceMean.cpp @@ -7,7 +7,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp index ad88c52..9ebdae8 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/NN/Softmax.cpp @@ -4,7 +4,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp similarity index 98% rename from src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp index a212b6f..ba556c3 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Post.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" @@ -9,7 +10,7 @@ #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" @@ -20,7 +21,7 @@ namespace onnx_mlir { namespace { static bool isWeightMaterializationHelperUser(Operation* op) { - return isa(op); + return isa(op); } static bool canPromoteInputBlockArgument(BlockArgument arg) { @@ -276,7 +277,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern(ctx); } diff --git a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Pre.cpp similarity index 66% rename from src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp rename to src/PIM/Conversion/ONNXToSpatial/Patterns/Pre.cpp index bf524d2..a2aa542 100644 --- a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Pre.cpp @@ -1,6 +1,5 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" using namespace mlir; @@ -12,7 +11,7 @@ namespace { } // namespace -void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) { +void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp index 84cef43..e388b83 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Gather.cpp @@ -6,7 +6,7 @@ #include "llvm/ADT/SmallVector.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp index 901d1ca..a766982 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Reshape.cpp @@ -5,7 +5,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp index be38062..56f1625 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Resize.cpp @@ -6,7 +6,7 @@ #include "llvm/ADT/STLExtras.h" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp index 2f154c9..ffdabfe 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Split.cpp @@ -3,7 +3,7 @@ #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp" -#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp" diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp new file mode 100644 index 0000000..07c524d --- /dev/null +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Transpose.cpp @@ -0,0 +1,75 @@ +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "llvm/ADT/SmallVector.h" + +#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp" +#include "src/Dialect/ONNX/ONNXOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +static Value createTransposeInit(Value input, + RankedTensorType resultType, + ArrayRef permutation, + ConversionPatternRewriter& rewriter, + Location loc) { + SmallVector sizes; + sizes.reserve(resultType.getRank()); + for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) { + if (!ShapedType::isDynamic(resultDim)) { + sizes.push_back(rewriter.getIndexAttr(resultDim)); + continue; + } + sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult()); + } + return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult(); +} + +static SmallVector getTransposePermutation(ONNXTransposeOp transposeOp) { + auto inputType = cast(transposeOp.getData().getType()); + SmallVector permutation; + if (auto permAttr = transposeOp.getPermAttr()) { + permutation.reserve(permAttr.size()); + for (IntegerAttr attr : permAttr.getAsRange()) + permutation.push_back(attr.getInt()); + return permutation; + } + + permutation.reserve(inputType.getRank()); + for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim) + permutation.push_back(dim); + return permutation; +} + +struct TransposeToLinalgTranspose : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp, + ONNXTransposeOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto inputType = dyn_cast(adaptor.getData().getType()); + auto resultType = dyn_cast(transposeOp.getResult().getType()); + if (!inputType || !resultType) + return failure(); + + SmallVector permutation = getTransposePermutation(transposeOp); + Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc()); + Value transposed = + linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation) + .getResult()[0]; + rewriter.replaceOp(transposeOp, transposed); + return success(); + } +}; + +} // namespace + +void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { + patterns.add(ctx); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp deleted file mode 100644 index 8c14f7f..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/MLIRContext.h" - -#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" - -namespace onnx_mlir { - -bool requiresPostRewrite(spatial::SpatCompute computeOp); - -bool requiresPostRewrite(spatial::SpatComputeBatch computeOp); - -void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - -void annotateWeightsConstants(mlir::func::FuncOp funcOp); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp b/src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp deleted file mode 100644 index 47085af..0000000 --- a/src/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once - -#include "mlir/IR/MLIRContext.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace onnx_mlir { - -void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index 1b56f12..00f11de 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -3,15 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}") add_public_tablegen_target(SpatialToPimIncGen) add_pim_library(OMSpatialToPim + Patterns.cpp SpatialToPimPass.cpp BatchCoreLoweringPatterns.cpp - ChannelLoweringPatterns.cpp Common.cpp ComputeLikeRegionUtils.cpp CoreLoweringPatterns.cpp - GlobalTensorMaterialization.cpp ReturnPathNormalization.cpp - TensorPackingPatterns.cpp + Patterns/ChannelLowering.cpp + Patterns/GlobalTensorMaterialization.cpp + Patterns/TensorPacking.cpp + Patterns/Transpose.cpp EXCLUDE_FROM_OM_LIBS @@ -19,6 +21,7 @@ add_pim_library(OMSpatialToPim SpatialToPimIncGen LINK_LIBS PUBLIC + MLIRLinalgDialect MLIRSCFDialect MLIRSCFUtils MLIRTransformUtils diff --git a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp b/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp deleted file mode 100644 index 068b2a3..0000000 --- a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "mlir/IR/PatternMatch.h" - -namespace onnx_mlir { - -void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index f62bfc8..2225b4c 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/IRMapping.h" @@ -24,7 +25,7 @@ static bool isChannelUseChainOp(Operation* op) { tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp, - ONNXTransposeOp, + linalg::TransposeOp, pim::PimTransposeOp>(op); } diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp deleted file mode 100644 index 7464dec..0000000 --- a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include "mlir/IR/PatternMatch.h" - -namespace onnx_mlir { - -void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns); - -} diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp new file mode 100644 index 0000000..e619a0f --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -0,0 +1,40 @@ +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace raptor { + +#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc" + +} // namespace raptor + +void populateInitialPatterns(RewritePatternSet& patterns) { + raptor::populateWithGenerated(patterns); + populateTransposeLoweringPatterns(patterns); +} + +void populateGlobalTensorMaterializationPatternPhase(RewritePatternSet& patterns) { + populateGlobalTensorMaterializationPatterns(patterns); +} + +void populateInitialTensorPackingPatterns(RewritePatternSet& patterns) { + populateTensorPackingPatterns(patterns); +} + +void populateCoreBodyPatterns(RewritePatternSet& patterns) { + raptor::populateWithGenerated(patterns); + populateTransposeLoweringPatterns(patterns); +} + +void populateFinalTensorPackingPatterns(RewritePatternSet& patterns) { + populateTensorPackingPatterns(patterns); +} + +void populateCommunicationPatterns(RewritePatternSet& patterns) { + populateChannelLoweringPatterns(patterns); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp b/src/PIM/Conversion/SpatialToPim/Patterns.hpp similarity index 64% rename from src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp rename to src/PIM/Conversion/SpatialToPim/Patterns.hpp index 73997dd..c1a7bad 100644 --- a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.hpp @@ -8,6 +8,18 @@ namespace onnx_mlir { +void populateInitialPatterns(mlir::RewritePatternSet& patterns); +void populateGlobalTensorMaterializationPatternPhase(mlir::RewritePatternSet& patterns); +void populateInitialTensorPackingPatterns(mlir::RewritePatternSet& patterns); +void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns); +void populateFinalTensorPackingPatterns(mlir::RewritePatternSet& patterns); +void populateCommunicationPatterns(mlir::RewritePatternSet& patterns); + +void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns); +void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns); +void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns); +void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns); + mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count); mlir::Value extractPackedChunk(mlir::Value packedValue, mlir::RankedTensorType chunkType, @@ -20,7 +32,6 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO mlir::OpBuilder& builder, mlir::Location loc); mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc); -void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns); void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter); } // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp similarity index 97% rename from src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp rename to src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp index 042e473..a938bd8 100644 --- a/src/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns/ChannelLowering.cpp @@ -1,6 +1,6 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp similarity index 99% rename from src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp rename to src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp index 4c0e4d5..3ac2289 100644 --- a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns/GlobalTensorMaterialization.cpp @@ -16,7 +16,7 @@ #include "Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; diff --git a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/TensorPacking.cpp similarity index 99% rename from src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp rename to src/PIM/Conversion/SpatialToPim/Patterns/TensorPacking.cpp index 9a65fc3..719b85c 100644 --- a/src/PIM/Conversion/SpatialToPim/TensorPackingPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns/TensorPacking.cpp @@ -1,4 +1,4 @@ -#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" using namespace mlir; diff --git a/src/PIM/Conversion/SpatialToPim/Patterns/Transpose.cpp b/src/PIM/Conversion/SpatialToPim/Patterns/Transpose.cpp new file mode 100644 index 0000000..ccd478a --- /dev/null +++ b/src/PIM/Conversion/SpatialToPim/Patterns/Transpose.cpp @@ -0,0 +1,38 @@ +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp" +#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +struct LinalgTransposeToPim final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter& rewriter) const override { + SmallVector permutationAttrs; + permutationAttrs.reserve(transposeOp.getPermutation().size()); + for (int64_t dim : transposeOp.getPermutation()) + permutationAttrs.push_back(rewriter.getI64IntegerAttr(dim)); + + auto permutation = rewriter.getArrayAttr(permutationAttrs); + auto pimTranspose = pim::PimTransposeOp::create(rewriter, + transposeOp.getLoc(), + TypeRange {transposeOp->getResult(0).getType()}, + transposeOp.getInput(), + permutation, + transposeOp.getInit()); + rewriter.replaceOp(transposeOp, pimTranspose.getOutput()); + return success(); + } +}; + +} // namespace + +void populateTransposeLoweringPatterns(RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 26cd27f..a1376f5 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -1,5 +1,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -40,7 +41,7 @@ static bool isReturnHelperChainOp(Operation* op) { tensor::ExpandShapeOp, tensor::CastOp, tosa::ReshapeOp, - ONNXTransposeOp, + linalg::TransposeOp, pim::PimTransposeOp>(op); } @@ -276,11 +277,10 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndice continue; } - if (auto transposeOp = dyn_cast(op)) { + if (auto transposeOp = dyn_cast(op)) { SmallVector nextIndices(currentIndices.size()); SmallVector nextShape(currentShape.size()); - for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange())) { - int64_t sourceIndex = attr.getInt(); + for (auto [destIndex, sourceIndex] : llvm::enumerate(transposeOp.getPermutation())) { nextIndices[destIndex] = currentIndices[sourceIndex]; nextShape[destIndex] = currentShape[sourceIndex]; } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td index 20fc586..31a4e13 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPim.td +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPim.td @@ -9,12 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td" include "src/Accelerators/PIM/Dialect/Pim/Pim.td" #endif // OP_BASE -def onnxToPimTranspose : Pat< - (ONNXTransposeOp:$srcOpRes $data, $perms), - (PimTransposeOp $data, $perms, - (NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes)) ->; - def spatToPimVMM : Pat< (SpatVMMOp:$srcOpRes $weight, $vector), (PimVMMOp $weight, $vector, diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 5c1ab93..b9de797 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -27,10 +27,8 @@ #include "Common/PimCommon.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp" -#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" #include "Conversion/SpatialToPim/Common.hpp" -#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" -#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp" +#include "Conversion/SpatialToPim/Patterns.hpp" #include "Dialect/Pim/PimOps.hpp" #include "Dialect/Spatial/SpatialOps.hpp" #include "Pass/PIMPasses.h" @@ -41,11 +39,6 @@ using namespace onnx_mlir; using namespace pim; namespace onnx_mlir { -namespace raptor { - -#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc" - -} // namespace raptor static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) { auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType(); @@ -159,7 +152,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { spatial::SpatExtractRowsOp>(); RewritePatternSet initialPatterns(ctx); - populateWithGenerated(initialPatterns); + populateInitialPatterns(initialPatterns); if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) { moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form"); signalPassFailure(); @@ -167,7 +160,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } RewritePatternSet globalTensorPatterns(ctx); - populateGlobalTensorMaterializationPatterns(globalTensorPatterns); + populateGlobalTensorMaterializationPatternPhase(globalTensorPatterns); walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns)); auto returnOp = cast(funcOp.front().getTerminator()); @@ -197,7 +190,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } RewritePatternSet initialTensorPackingPatterns(ctx); - populateTensorPackingPatterns(initialTensorPackingPatterns); + populateInitialTensorPackingPatterns(initialTensorPackingPatterns); walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns)); eraseUnusedTensorPackingOps(funcOp, rewriter); @@ -214,7 +207,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { } RewritePatternSet coreBodyPatterns(ctx); - populateWithGenerated(coreBodyPatterns); + populateCoreBodyPatterns(coreBodyPatterns); populateAffineToStdConversionPatterns(coreBodyPatterns); FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns)); @@ -257,7 +250,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { eraseOpsToRemove(); RewritePatternSet finalTensorPackingPatterns(ctx); - populateTensorPackingPatterns(finalTensorPackingPatterns); + populateFinalTensorPackingPatterns(finalTensorPackingPatterns); walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns)); eraseUnusedTensorPackingOps(funcOp, rewriter); @@ -277,7 +270,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() { spatial::SpatExtractRowsOp>(); RewritePatternSet communicationPatterns(ctx); - populateChannelLoweringPatterns(communicationPatterns); + populateCommunicationPatterns(communicationPatterns); if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) { funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops"); signalPassFailure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index 70d4cd5..508de55 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -60,7 +60,7 @@ std::vector> buildReverseLevels(const ComputeGraph& graph) { } void verifyOctTableSize(size_t nodeCount, size_t processorCount) { - constexpr size_t kMaxOctTableBytes = 1ull << 30; + constexpr size_t kMaxOctTableBytes = 1ull << 35; if (nodeCount == 0 || processorCount == 0) return; if (processorCount > std::numeric_limits::max() / sizeof(Time))