Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
Validate Operations / validate-operations (push) Has been cancelled
This commit is contained in:
@@ -18,6 +18,7 @@ add_pim_library(OMPimCommon
|
|||||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRLinalgDialect
|
||||||
onnx
|
onnx
|
||||||
SpatialOps
|
SpatialOps
|
||||||
PimOps
|
PimOps
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
|
||||||
@@ -131,8 +132,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
|||||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
||||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
|
if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
|
||||||
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
|
|||||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||||
llvm::raw_os_ostream os(file);
|
llvm::raw_os_ostream os(file);
|
||||||
mlir::OpPrintingFlags flags;
|
mlir::OpPrintingFlags flags;
|
||||||
flags.elideLargeElementsAttrs();
|
flags.elideLargeElementsAttrs().enableDebugInfo(true,false);
|
||||||
moduleOp.print(os, flags);
|
moduleOp.print(os, flags);
|
||||||
os.flush();
|
os.flush();
|
||||||
file.close();
|
file.close();
|
||||||
|
|||||||
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|||||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||||
|
|
||||||
add_pim_library(OMONNXToSpatial
|
add_pim_library(OMONNXToSpatial
|
||||||
ConversionPatterns.cpp
|
Patterns.cpp
|
||||||
CompileTime.cpp
|
CompileTime.cpp
|
||||||
ONNXToSpatialVerifier.cpp
|
ONNXToSpatialVerifier.cpp
|
||||||
PrePatterns.cpp
|
Patterns/Pre.cpp
|
||||||
PostPatterns.cpp
|
Patterns/Post.cpp
|
||||||
|
Patterns/GeneratedConversion.cpp
|
||||||
Patterns/Math/Conv.cpp
|
Patterns/Math/Conv.cpp
|
||||||
Patterns/Math/Elementwise.cpp
|
Patterns/Math/Elementwise.cpp
|
||||||
Patterns/Math/Gemm.cpp
|
Patterns/Math/Gemm.cpp
|
||||||
@@ -22,6 +23,7 @@ add_pim_library(OMONNXToSpatial
|
|||||||
Patterns/Tensor/Resize.cpp
|
Patterns/Tensor/Resize.cpp
|
||||||
Patterns/Tensor/Reshape.cpp
|
Patterns/Tensor/Reshape.cpp
|
||||||
Patterns/Tensor/Split.cpp
|
Patterns/Tensor/Split.cpp
|
||||||
|
Patterns/Tensor/Transpose.cpp
|
||||||
ONNXToSpatialPass.cpp
|
ONNXToSpatialPass.cpp
|
||||||
Common/ComputeRegionBuilder.cpp
|
Common/ComputeRegionBuilder.cpp
|
||||||
Common/ShapeTilingUtils.cpp
|
Common/ShapeTilingUtils.cpp
|
||||||
@@ -33,6 +35,7 @@ add_pim_library(OMONNXToSpatial
|
|||||||
ONNXToSpatialIncGen
|
ONNXToSpatialIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRLinalgDialect
|
||||||
MLIRSCFDialect
|
MLIRSCFDialect
|
||||||
MLIRTosaDialect
|
MLIRTosaDialect
|
||||||
OMCompilerOptions
|
OMCompilerOptions
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
|
|||||||
value = collapseShapeOp.getSrc();
|
value = collapseShapeOp.getSrc();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||||
value = transposeOp.getData();
|
value = transposeOp.getInput();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
|
|||||||
return referencedValue.getResult();
|
return referencedValue.getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
IRMapping localMapper;
|
IRMapping localMapper;
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
@@ -171,6 +172,16 @@ static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm:
|
|||||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||||
|
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
|
||||||
|
if (!inputAttr)
|
||||||
|
return nullptr;
|
||||||
|
|
||||||
|
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
|
||||||
|
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||||
|
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||||
if (!inputAttr)
|
if (!inputAttr)
|
||||||
@@ -226,6 +237,9 @@ getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visit
|
|||||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||||
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
|
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
|
||||||
|
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||||
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
@@ -14,10 +15,8 @@
|
|||||||
#include "Common/Common.hpp"
|
#include "Common/Common.hpp"
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -86,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
|||||||
returnOp.setOperand(index, computeResult);
|
returnOp.setOperand(index, computeResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
|
||||||
IRRewriter rewriter(funcOp.getContext());
|
|
||||||
Block& entryBlock = funcOp.getFunctionBody().front();
|
|
||||||
|
|
||||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
|
||||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
|
||||||
if (!transposeOp || isCompileTimeOp(transposeOp))
|
|
||||||
continue;
|
|
||||||
|
|
||||||
// Transpose stays globally legal because constant/view-only cases are
|
|
||||||
// allowed on the host. Any residual runtime transpose must be sunk into
|
|
||||||
// spat.compute before the host legality check.
|
|
||||||
auto resultType = transposeOp.getResult().getType();
|
|
||||||
rewriter.setInsertionPoint(transposeOp);
|
|
||||||
auto computeOp = createSpatCompute<1>(
|
|
||||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
|
||||||
Value transposed =
|
|
||||||
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
|
||||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
|
||||||
});
|
|
||||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void ONNXToSpatialPass::runOnOperation() {
|
void ONNXToSpatialPass::runOnOperation() {
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = &getContext();
|
MLIRContext* ctx = &getContext();
|
||||||
@@ -117,6 +92,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
ConversionTarget preTarget(*ctx);
|
ConversionTarget preTarget(*ctx);
|
||||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
affine::AffineDialect,
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
@@ -156,11 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
ConversionTarget target(*ctx);
|
ConversionTarget target(*ctx);
|
||||||
target.addLegalDialect<spatial::SpatialDialect,
|
target.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
affine::AffineDialect,
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
scf::SCFDialect>();
|
scf::SCFDialect>();
|
||||||
target.addIllegalOp<ONNXMatMulOp>();
|
target.addIllegalOp<ONNXMatMulOp>();
|
||||||
|
target.addIllegalOp<ONNXTransposeOp>();
|
||||||
target.addIllegalOp<ONNXAddOp>();
|
target.addIllegalOp<ONNXAddOp>();
|
||||||
target.addIllegalOp<ONNXDivOp>();
|
target.addIllegalOp<ONNXDivOp>();
|
||||||
target.addIllegalOp<ONNXMulOp>();
|
target.addIllegalOp<ONNXMulOp>();
|
||||||
@@ -187,9 +165,14 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RewritePatternSet transposePatterns(ctx);
|
||||||
|
populateTransposePatterns(transposePatterns, ctx);
|
||||||
|
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
|
||||||
|
|
||||||
ConversionTarget earlyPostTarget(*ctx);
|
ConversionTarget earlyPostTarget(*ctx);
|
||||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
affine::AffineDialect,
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
@@ -205,6 +188,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
ConversionTarget postTarget(*ctx);
|
ConversionTarget postTarget(*ctx);
|
||||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||||
ONNXDialect,
|
ONNXDialect,
|
||||||
|
linalg::LinalgDialect,
|
||||||
tensor::TensorDialect,
|
tensor::TensorDialect,
|
||||||
affine::AffineDialect,
|
affine::AffineDialect,
|
||||||
arith::ArithDialect,
|
arith::ArithDialect,
|
||||||
@@ -222,8 +206,6 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
|
||||||
|
|
||||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|||||||
+11
-9
@@ -1,19 +1,16 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
namespace {
|
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
populateGeneratedPrePatterns(patterns, ctx);
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
}
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
|
||||||
patterns.add<removeLRN>(ctx);
|
|
||||||
|
|
||||||
|
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
populateGeneratedConversionPatterns(patterns, ctx);
|
||||||
populateElementwisePatterns(patterns, ctx);
|
populateElementwisePatterns(patterns, ctx);
|
||||||
populateGemmPatterns(patterns, ctx);
|
populateGemmPatterns(patterns, ctx);
|
||||||
populateConvPatterns(patterns, ctx);
|
populateConvPatterns(patterns, ctx);
|
||||||
@@ -27,6 +24,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
|
|||||||
populateResizePatterns(patterns, ctx);
|
populateResizePatterns(patterns, ctx);
|
||||||
populateReshapePatterns(patterns, ctx);
|
populateReshapePatterns(patterns, ctx);
|
||||||
populateSplitPatterns(patterns, ctx);
|
populateSplitPatterns(patterns, ctx);
|
||||||
|
populateTransposePatterns(patterns, ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
populateWeightPromotionPatterns(patterns, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
+14
-13
@@ -1,38 +1,39 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||||
|
|
||||||
|
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||||
|
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||||
|
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.add<removeLRN>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
+4
-3
@@ -1,5 +1,6 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
@@ -9,7 +10,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
@@ -20,7 +21,7 @@ namespace onnx_mlir {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
static bool isWeightMaterializationHelperUser(Operation* op) {
|
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||||
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
|
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||||
@@ -276,7 +277,7 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
+2
-3
@@ -1,6 +1,5 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -12,7 +11,7 @@ namespace {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||||
patterns.add<onnxToArithConstant>(ctx);
|
patterns.add<onnxToArithConstant>(ctx);
|
||||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
#include "llvm/ADT/SmallVector.h"
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,75 @@
|
|||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
static Value createTransposeInit(Value input,
|
||||||
|
RankedTensorType resultType,
|
||||||
|
ArrayRef<int64_t> permutation,
|
||||||
|
ConversionPatternRewriter& rewriter,
|
||||||
|
Location loc) {
|
||||||
|
SmallVector<OpFoldResult> sizes;
|
||||||
|
sizes.reserve(resultType.getRank());
|
||||||
|
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
|
||||||
|
if (!ShapedType::isDynamic(resultDim)) {
|
||||||
|
sizes.push_back(rewriter.getIndexAttr(resultDim));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
|
||||||
|
}
|
||||||
|
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
static SmallVector<int64_t> getTransposePermutation(ONNXTransposeOp transposeOp) {
|
||||||
|
auto inputType = cast<RankedTensorType>(transposeOp.getData().getType());
|
||||||
|
SmallVector<int64_t> permutation;
|
||||||
|
if (auto permAttr = transposeOp.getPermAttr()) {
|
||||||
|
permutation.reserve(permAttr.size());
|
||||||
|
for (IntegerAttr attr : permAttr.getAsRange<IntegerAttr>())
|
||||||
|
permutation.push_back(attr.getInt());
|
||||||
|
return permutation;
|
||||||
|
}
|
||||||
|
|
||||||
|
permutation.reserve(inputType.getRank());
|
||||||
|
for (int64_t dim = inputType.getRank() - 1; dim >= 0; --dim)
|
||||||
|
permutation.push_back(dim);
|
||||||
|
return permutation;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
|
||||||
|
ONNXTransposeOpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
|
||||||
|
if (!inputType || !resultType)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
SmallVector<int64_t> permutation = getTransposePermutation(transposeOp);
|
||||||
|
Value init = createTransposeInit(adaptor.getData(), resultType, permutation, rewriter, transposeOp.getLoc());
|
||||||
|
Value transposed =
|
||||||
|
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, permutation)
|
||||||
|
.getResult()[0];
|
||||||
|
rewriter.replaceOp(transposeOp, transposed);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||||
|
patterns.add<TransposeToLinalgTranspose>(ctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
|
||||||
|
|
||||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
|
||||||
|
|
||||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/MLIRContext.h"
|
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -3,15 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|||||||
add_public_tablegen_target(SpatialToPimIncGen)
|
add_public_tablegen_target(SpatialToPimIncGen)
|
||||||
|
|
||||||
add_pim_library(OMSpatialToPim
|
add_pim_library(OMSpatialToPim
|
||||||
|
Patterns.cpp
|
||||||
SpatialToPimPass.cpp
|
SpatialToPimPass.cpp
|
||||||
BatchCoreLoweringPatterns.cpp
|
BatchCoreLoweringPatterns.cpp
|
||||||
ChannelLoweringPatterns.cpp
|
|
||||||
Common.cpp
|
Common.cpp
|
||||||
ComputeLikeRegionUtils.cpp
|
ComputeLikeRegionUtils.cpp
|
||||||
CoreLoweringPatterns.cpp
|
CoreLoweringPatterns.cpp
|
||||||
GlobalTensorMaterialization.cpp
|
|
||||||
ReturnPathNormalization.cpp
|
ReturnPathNormalization.cpp
|
||||||
TensorPackingPatterns.cpp
|
Patterns/ChannelLowering.cpp
|
||||||
|
Patterns/GlobalTensorMaterialization.cpp
|
||||||
|
Patterns/TensorPacking.cpp
|
||||||
|
Patterns/Transpose.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
@@ -19,6 +21,7 @@ add_pim_library(OMSpatialToPim
|
|||||||
SpatialToPimIncGen
|
SpatialToPimIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRLinalgDialect
|
||||||
MLIRSCFDialect
|
MLIRSCFDialect
|
||||||
MLIRSCFUtils
|
MLIRSCFUtils
|
||||||
MLIRTransformUtils
|
MLIRTransformUtils
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
|
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
@@ -24,7 +25,7 @@ static bool isChannelUseChainOp(Operation* op) {
|
|||||||
tensor::ExpandShapeOp,
|
tensor::ExpandShapeOp,
|
||||||
tensor::CastOp,
|
tensor::CastOp,
|
||||||
tosa::ReshapeOp,
|
tosa::ReshapeOp,
|
||||||
ONNXTransposeOp,
|
linalg::TransposeOp,
|
||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
|
|
||||||
namespace onnx_mlir {
|
|
||||||
|
|
||||||
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace raptor {
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||||
|
|
||||||
|
} // namespace raptor
|
||||||
|
|
||||||
|
void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||||
|
raptor::populateWithGenerated(patterns);
|
||||||
|
populateTransposeLoweringPatterns(patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populateGlobalTensorMaterializationPatternPhase(RewritePatternSet& patterns) {
|
||||||
|
populateGlobalTensorMaterializationPatterns(patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populateInitialTensorPackingPatterns(RewritePatternSet& patterns) {
|
||||||
|
populateTensorPackingPatterns(patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
||||||
|
raptor::populateWithGenerated(patterns);
|
||||||
|
populateTransposeLoweringPatterns(patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populateFinalTensorPackingPatterns(RewritePatternSet& patterns) {
|
||||||
|
populateTensorPackingPatterns(patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
void populateCommunicationPatterns(RewritePatternSet& patterns) {
|
||||||
|
populateChannelLoweringPatterns(patterns);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
+12
-1
@@ -8,6 +8,18 @@
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
|
void populateInitialPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateGlobalTensorMaterializationPatternPhase(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateInitialTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateFinalTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateCommunicationPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
|
void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||||
|
|
||||||
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
||||||
mlir::Value extractPackedChunk(mlir::Value packedValue,
|
mlir::Value extractPackedChunk(mlir::Value packedValue,
|
||||||
mlir::RankedTensorType chunkType,
|
mlir::RankedTensorType chunkType,
|
||||||
@@ -20,7 +32,6 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
|
|||||||
mlir::OpBuilder& builder,
|
mlir::OpBuilder& builder,
|
||||||
mlir::Location loc);
|
mlir::Location loc);
|
||||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
|
||||||
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
+1
-1
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
+1
-1
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
+1
-1
@@ -1,4 +1,4 @@
|
|||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
|
||||||
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
namespace onnx_mlir {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct LinalgTransposeToPim final : OpRewritePattern<linalg::TransposeOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||||
|
SmallVector<Attribute> permutationAttrs;
|
||||||
|
permutationAttrs.reserve(transposeOp.getPermutation().size());
|
||||||
|
for (int64_t dim : transposeOp.getPermutation())
|
||||||
|
permutationAttrs.push_back(rewriter.getI64IntegerAttr(dim));
|
||||||
|
|
||||||
|
auto permutation = rewriter.getArrayAttr(permutationAttrs);
|
||||||
|
auto pimTranspose = pim::PimTransposeOp::create(rewriter,
|
||||||
|
transposeOp.getLoc(),
|
||||||
|
TypeRange {transposeOp->getResult(0).getType()},
|
||||||
|
transposeOp.getInput(),
|
||||||
|
permutation,
|
||||||
|
transposeOp.getInit());
|
||||||
|
rewriter.replaceOp(transposeOp, pimTranspose.getOutput());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void populateTransposeLoweringPatterns(RewritePatternSet& patterns) {
|
||||||
|
patterns.add<LinalgTransposeToPim>(patterns.getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace onnx_mlir
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||||
@@ -40,7 +41,7 @@ static bool isReturnHelperChainOp(Operation* op) {
|
|||||||
tensor::ExpandShapeOp,
|
tensor::ExpandShapeOp,
|
||||||
tensor::CastOp,
|
tensor::CastOp,
|
||||||
tosa::ReshapeOp,
|
tosa::ReshapeOp,
|
||||||
ONNXTransposeOp,
|
linalg::TransposeOp,
|
||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -276,11 +277,10 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
|
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op)) {
|
||||||
SmallVector<int64_t> nextIndices(currentIndices.size());
|
SmallVector<int64_t> nextIndices(currentIndices.size());
|
||||||
SmallVector<int64_t> nextShape(currentShape.size());
|
SmallVector<int64_t> nextShape(currentShape.size());
|
||||||
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
|
for (auto [destIndex, sourceIndex] : llvm::enumerate(transposeOp.getPermutation())) {
|
||||||
int64_t sourceIndex = attr.getInt();
|
|
||||||
nextIndices[destIndex] = currentIndices[sourceIndex];
|
nextIndices[destIndex] = currentIndices[sourceIndex];
|
||||||
nextShape[destIndex] = currentShape[sourceIndex];
|
nextShape[destIndex] = currentShape[sourceIndex];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,12 +9,6 @@ include "src/Accelerators/PIM/Dialect/Spatial/Spatial.td"
|
|||||||
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
include "src/Accelerators/PIM/Dialect/Pim/Pim.td"
|
||||||
#endif // OP_BASE
|
#endif // OP_BASE
|
||||||
|
|
||||||
def onnxToPimTranspose : Pat<
|
|
||||||
(ONNXTransposeOp:$srcOpRes $data, $perms),
|
|
||||||
(PimTransposeOp $data, $perms,
|
|
||||||
(NativeCodeCall<"onnx_mlir::getBestOutputTensorFromOperandsOrAllocate($_builder, $0.getDefiningOp())"> $srcOpRes))
|
|
||||||
>;
|
|
||||||
|
|
||||||
def spatToPimVMM : Pat<
|
def spatToPimVMM : Pat<
|
||||||
(SpatVMMOp:$srcOpRes $weight, $vector),
|
(SpatVMMOp:$srcOpRes $weight, $vector),
|
||||||
(PimVMMOp $weight, $vector,
|
(PimVMMOp $weight, $vector,
|
||||||
|
|||||||
@@ -27,10 +27,8 @@
|
|||||||
|
|
||||||
#include "Common/PimCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
|
||||||
#include "Conversion/SpatialToPim/Common.hpp"
|
#include "Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
#include "Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Spatial/SpatialOps.hpp"
|
#include "Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "Pass/PIMPasses.h"
|
#include "Pass/PIMPasses.h"
|
||||||
@@ -41,11 +39,6 @@ using namespace onnx_mlir;
|
|||||||
using namespace pim;
|
using namespace pim;
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace raptor {
|
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
|
||||||
|
|
||||||
} // namespace raptor
|
|
||||||
|
|
||||||
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
static memref::GlobalOp getOrCreateZeroGlobal(IRRewriter& rewriter, Location loc, RankedTensorType tensorType) {
|
||||||
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
auto moduleOp = rewriter.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
|
||||||
@@ -159,7 +152,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
spatial::SpatExtractRowsOp>();
|
spatial::SpatExtractRowsOp>();
|
||||||
|
|
||||||
RewritePatternSet initialPatterns(ctx);
|
RewritePatternSet initialPatterns(ctx);
|
||||||
populateWithGenerated(initialPatterns);
|
populateInitialPatterns(initialPatterns);
|
||||||
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
if (failed(applyPartialConversion(moduleOp, target, std::move(initialPatterns)))) {
|
||||||
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
moduleOp.emitError("failed to lower required Spatial ops to the initial PIM form");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
@@ -167,7 +160,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RewritePatternSet globalTensorPatterns(ctx);
|
RewritePatternSet globalTensorPatterns(ctx);
|
||||||
populateGlobalTensorMaterializationPatterns(globalTensorPatterns);
|
populateGlobalTensorMaterializationPatternPhase(globalTensorPatterns);
|
||||||
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
walkAndApplyPatterns(moduleOp, std::move(globalTensorPatterns));
|
||||||
|
|
||||||
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
auto returnOp = cast<func::ReturnOp>(funcOp.front().getTerminator());
|
||||||
@@ -197,7 +190,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RewritePatternSet initialTensorPackingPatterns(ctx);
|
RewritePatternSet initialTensorPackingPatterns(ctx);
|
||||||
populateTensorPackingPatterns(initialTensorPackingPatterns);
|
populateInitialTensorPackingPatterns(initialTensorPackingPatterns);
|
||||||
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
walkAndApplyPatterns(funcOp, std::move(initialTensorPackingPatterns));
|
||||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||||
|
|
||||||
@@ -214,7 +207,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
RewritePatternSet coreBodyPatterns(ctx);
|
RewritePatternSet coreBodyPatterns(ctx);
|
||||||
populateWithGenerated(coreBodyPatterns);
|
populateCoreBodyPatterns(coreBodyPatterns);
|
||||||
populateAffineToStdConversionPatterns(coreBodyPatterns);
|
populateAffineToStdConversionPatterns(coreBodyPatterns);
|
||||||
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
FrozenRewritePatternSet frozenCoreBodyPatterns(std::move(coreBodyPatterns));
|
||||||
|
|
||||||
@@ -257,7 +250,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
eraseOpsToRemove();
|
eraseOpsToRemove();
|
||||||
|
|
||||||
RewritePatternSet finalTensorPackingPatterns(ctx);
|
RewritePatternSet finalTensorPackingPatterns(ctx);
|
||||||
populateTensorPackingPatterns(finalTensorPackingPatterns);
|
populateFinalTensorPackingPatterns(finalTensorPackingPatterns);
|
||||||
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
walkAndApplyPatterns(funcOp, std::move(finalTensorPackingPatterns));
|
||||||
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
eraseUnusedTensorPackingOps(funcOp, rewriter);
|
||||||
|
|
||||||
@@ -277,7 +270,7 @@ void onnx_mlir::raptor::SpatialToPimPass::runOnOperation() {
|
|||||||
spatial::SpatExtractRowsOp>();
|
spatial::SpatExtractRowsOp>();
|
||||||
|
|
||||||
RewritePatternSet communicationPatterns(ctx);
|
RewritePatternSet communicationPatterns(ctx);
|
||||||
populateChannelLoweringPatterns(communicationPatterns);
|
populateCommunicationPatterns(communicationPatterns);
|
||||||
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
if (failed(applyFullConversion(funcOp, communicationTarget, std::move(communicationPatterns)))) {
|
||||||
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
funcOp.emitOpError("failed to lower Spatial communication ops to PIM communication ops");
|
||||||
signalPassFailure();
|
signalPassFailure();
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||||
constexpr size_t kMaxOctTableBytes = 1ull << 30;
|
constexpr size_t kMaxOctTableBytes = 1ull << 35;
|
||||||
if (nodeCount == 0 || processorCount == 0)
|
if (nodeCount == 0 || processorCount == 0)
|
||||||
return;
|
return;
|
||||||
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
if (processorCount > std::numeric_limits<size_t>::max() / sizeof(Time))
|
||||||
|
|||||||
Reference in New Issue
Block a user