replace deprecated "rewriter.create<Op>(loc, ...)" pattern with the "Op::create(rewriter, loc, ...)" builder form

refactor PIM to Pim everywhere except for the accelerator name
This commit is contained in:
NiccoloN
2026-03-20 13:30:53 +01:00
parent 916a09414c
commit bb6dcd38a3
32 changed files with 222 additions and 212 deletions

View File

@@ -43,7 +43,7 @@ add_onnx_mlir_library(OMPIMAccel
PimOps PimOps
OMONNXToSpatial OMONNXToSpatial
OMSpatialToGraphviz OMSpatialToGraphviz
OMSpatialToPIM OMSpatialToPim
OMPIMCommon OMPimCommon
MLIRTensorInferTypeOpInterfaceImpl MLIRTensorInferTypeOpInterfaceImpl
) )

View File

@@ -1,5 +1,5 @@
add_onnx_mlir_library(OMPIMCommon add_onnx_mlir_library(OMPimCommon
PIMCommon.cpp PimCommon.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -3,10 +3,10 @@
#include <filesystem> #include <filesystem>
#include <fstream> #include <fstream>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -37,8 +37,7 @@ FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>()); SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) { if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
<< entryPoints.size();
return failure(); return failure();
} }
if (!entryPoints.empty()) { if (!entryPoints.empty()) {
@@ -61,10 +60,9 @@ FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
return mainGraphFunc; return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs; SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>()) { for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal()) if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp); nonExternalFuncs.push_back(funcOp);
}
if (nonExternalFuncs.size() == 1) if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front(); return nonExternalFuncs.front();
@@ -72,11 +70,11 @@ FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
return failure(); return failure();
} }
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME) != nullptr; } bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) { void markWeightAlways(Operation* op) {
assert(op && "expected valid op"); assert(op && "expected valid op");
op->setAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME, UnitAttr::get(op->getContext())); op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
} }
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) { memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {

View File

@@ -1,8 +1,8 @@
#pragma once #pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
@@ -10,8 +10,8 @@
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate"; const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
inline constexpr llvm::StringRef PIM_WEIGHT_ALWAYS_ATTR_NAME = "weightAlways"; inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir { namespace onnx_mlir {

View File

@@ -13,9 +13,9 @@
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" #include "Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp" #include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"

View File

@@ -34,7 +34,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPIMPass()); pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass()); // pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim")); pm.addPass(createMessagePass("Spatial lowered to Pim"));
} }

View File

@@ -1,3 +1,3 @@
add_subdirectory(ONNXToSpatial) add_subdirectory(ONNXToSpatial)
add_subdirectory(SpatialToGraphviz) add_subdirectory(SpatialToGraphviz)
add_subdirectory(SpatialToPIM) add_subdirectory(SpatialToPim)

View File

@@ -23,7 +23,7 @@ add_onnx_mlir_library(OMONNXToSpatial
OMPimCompilerOptions OMPimCompilerOptions
OMONNXOps OMONNXOps
SpatialOps SpatialOps
OMPIMCommon OMPimCommon
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH} ${PIM_INCLUDE_PATH}

View File

@@ -10,7 +10,7 @@
#include <cassert> #include <cassert>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -91,7 +91,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))}; SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType()); auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult(); auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
Value cSlice = c; Value cSlice = c;
if (hasC) { if (hasC) {
@@ -100,13 +100,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))}; SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType()); auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult(); cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
} }
else else
assert("C should be a vector" && isVectorShape(getTensorShape(c))); assert("C should be a vector" && isVectorShape(getTensorShape(c)));
} }
auto gemvOp = rewriter.create<ONNXGemmOp>(loc, auto gemvOp = ONNXGemmOp::create(rewriter,
loc,
outRowType, outRowType,
aSlice, aSlice,
b, b,
@@ -119,7 +120,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto concatComputeOp = auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps); spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
auto* concatBlock = new Block(); auto* concatBlock = new Block();
for (auto gemvOp : gemvOps) for (auto gemvOp : gemvOps)
@@ -128,8 +129,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.setInsertionPointToStart(concatBlock); rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments(); auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs); auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
@@ -170,25 +171,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
if (transA) { if (transA) {
auto aShape = aType.getShape(); auto aShape = aType.getShape();
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType()); auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0})); a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
} }
if (transB) { if (transB) {
auto bShape = bType.getShape(); auto bShape = bType.getShape();
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType()); auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0})); b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
} }
if (alpha != 1.0f) { if (alpha != 1.0f) {
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType()); auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha}); auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue); auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor); a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
} }
if (hasC && beta != 1.0f) { if (hasC && beta != 1.0f) {
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType()); auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta}); auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue); auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor); c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
} }
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue()); auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
@@ -235,7 +236,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
weights.push_back(bTiles[outSliceId][coreId][aSliceId]); weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
auto computeOp = auto computeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]); spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
auto* computeBlock = new Block(); auto* computeBlock = new Block();
for (auto aHSlice : aHSlices[coreId]) for (auto aHSlice : aHSlices[coreId])
@@ -248,11 +249,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
vmmOutputs.reserve(computeArgs.size()); vmmOutputs.reserve(computeArgs.size());
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++) for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
vmmOutputs.push_back( vmmOutputs.push_back(
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId])); spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty"); assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
Value partialVmmSum = sumTensors(vmmOutputs, rewriter); Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum); spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
partialResults.push_back(computeOp.getResult(0)); partialResults.push_back(computeOp.getResult(0));
@@ -264,7 +265,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
} }
auto reduceComputeOp = auto reduceComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults); spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
auto* reduceBlock = new Block(); auto* reduceBlock = new Block();
for (auto partialResult : partialResults) for (auto partialResult : partialResults)
@@ -274,14 +275,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
auto blockArgs = reduceBlock->getArguments(); auto blockArgs = reduceBlock->getArguments();
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter); Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice); spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
rewriter.setInsertionPointAfter(reduceComputeOp); rewriter.setInsertionPointAfter(reduceComputeOp);
outHSlices.push_back(reduceComputeOp.getResult(0)); outHSlices.push_back(reduceComputeOp.getResult(0));
} }
auto concatComputeOp = auto concatComputeOp =
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices); spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
auto* concatBlock = new Block(); auto* concatBlock = new Block();
for (auto outHSlice : outHSlices) for (auto outHSlice : outHSlices)
@@ -290,8 +291,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
rewriter.setInsertionPointToStart(concatBlock); rewriter.setInsertionPointToStart(concatBlock);
auto blockArgs = concatBlock->getArguments(); auto blockArgs = concatBlock->getArguments();
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs); auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult()); spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
rewriter.replaceOp(gemmOp, concatComputeOp); rewriter.replaceOp(gemmOp, concatComputeOp);
return success(); return success();
@@ -335,9 +336,9 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
reducer.applyReducePattern( reducer.applyReducePattern(
softmaxOpsToReduce, softmaxOpsToReduce,
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); }, [&](Value a, Value b) { return spatial::SpatVAddOp::create(rewriter, loc, scalarTensorType, a, b); },
/* preprocess = */ /* preprocess = */
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); }, [&](Value a) { return spatial::SpatSumOp::create(rewriter, loc, scalarTensorType, a); },
[&](Value softmaxDivisor) { [&](Value softmaxDivisor) {
// Signal that this is the compute with the softmax divisor // Signal that this is the compute with the softmax divisor
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp()); auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
@@ -345,7 +346,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
// Broadcast the divisor to all the cores // Broadcast the divisor to all the cores
rewriter.setInsertionPointAfterValue(softmaxDivisor); rewriter.setInsertionPointAfterValue(softmaxDivisor);
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor); spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, softmaxChannel, softmaxDivisor);
/* /*
* softmaxDividend = onnx.exp (...) * softmaxDividend = onnx.exp (...)
@@ -395,7 +396,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
} }
else { else {
rewriter.setInsertionPoint(yieldOp); rewriter.setInsertionPoint(yieldOp);
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel); divisor = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, loc, scalarTensorType, softmaxChannel);
} }
// Walk the chain of operations until we find the ONNXExpOp: this is // Walk the chain of operations until we find the ONNXExpOp: this is
@@ -405,7 +406,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second)); Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
rewriter.setInsertionPoint(yieldOp); rewriter.setInsertionPoint(yieldOp);
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor); Value newOutputTile = spatial::SpatVSDivOp::create(rewriter, loc, oldOutputTile.getType(), oldOutputTile, divisor);
auto yieldOperandNum = yieldOp->getNumOperands(); auto yieldOperandNum = yieldOp->getNumOperands();
yieldOp->insertOperands(yieldOperandNum, newOutputTile); yieldOp->insertOperands(yieldOperandNum, newOutputTile);

View File

@@ -15,7 +15,7 @@
#include <cmath> #include <cmath>
#include <cstddef> #include <cstddef>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
@@ -111,15 +111,15 @@ Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
// 1. Add a channel before the first computeOp // 1. Add a channel before the first computeOp
rewriter.setInsertionPoint(firstCompute); rewriter.setInsertionPoint(firstCompute);
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType); auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Add a sendOp after the first value // 2. Add a sendOp after the first value
rewriter.setInsertionPointAfterValue(firstValue); rewriter.setInsertionPointAfterValue(firstValue);
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue); spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
// 3. Add a receiveOp after the second value // 3. Add a receiveOp after the second value
rewriter.setInsertionPointAfterValue(secondValue); rewriter.setInsertionPointAfterValue(secondValue);
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel); auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
// 4. Apply reduction between second value and received value // 4. Apply reduction between second value and received value
rewriter.setInsertionPointAfterValue(receivedValue); rewriter.setInsertionPointAfterValue(receivedValue);
@@ -188,13 +188,14 @@ Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rew
// directly under func.func (i.e. alongside ComputeOps) // directly under func.func (i.e. alongside ComputeOps)
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp()); auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc, auto divisorValue = spatial::SpatConstantOp::create(rewriter,
loc,
scalarTensor, scalarTensor,
rewriter.getI64IntegerAttr(divisorNumber), rewriter.getI64IntegerAttr(divisorNumber),
/* should_allocate = */ rewriter.getBoolAttr(true)); /* should_allocate = */ rewriter.getBoolAttr(true));
rewriter.setInsertionPointAfterValue(valueToDivide); rewriter.setInsertionPointAfterValue(valueToDivide);
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue); return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
} }
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp> template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
@@ -257,7 +258,8 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) { if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
Location tileLoc = extractSliceOp.getLoc(); Location tileLoc = extractSliceOp.getLoc();
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc, auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
tileLoc,
extractSliceOp.getResultType(), extractSliceOp.getResultType(),
/* xbarWeights =*/ValueRange(), /* xbarWeights =*/ValueRange(),
extractSliceOp.getResult()); extractSliceOp.getResult());
@@ -267,7 +269,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc); auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
rewriter.setInsertionPointToStart(tempComputeOpBlock); rewriter.setInsertionPointToStart(tempComputeOpBlock);
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg); spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
rewriter.setInsertionPointAfter(tempComputeOp); rewriter.setInsertionPointAfter(tempComputeOp);
inputTiles[t][x][y] = tempComputeOp.getResult(0); inputTiles[t][x][y] = tempComputeOp.getResult(0);
} }
@@ -356,7 +358,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
Value reducedWithinCompute = applyReducePatternNew( Value reducedWithinCompute = applyReducePatternNew(
valuesToPool, valuesToPool,
rewriter, rewriter,
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); }, [&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
nullptr, nullptr,
postProcessFn); postProcessFn);
@@ -369,16 +371,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
// Create a new channel before the computeOp // Create a new channel before the computeOp
rewriter.setInsertionPoint(computeOpOfReduced); rewriter.setInsertionPoint(computeOpOfReduced);
auto reduceChannel = auto reduceChannel =
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext())); spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
// Send value through the channel // Send value through the channel
rewriter.setInsertionPointAfterValue(reducedWithinCompute); rewriter.setInsertionPointAfterValue(reducedWithinCompute);
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute); spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
// Receive after the computeOp // Receive after the computeOp
rewriter.setInsertionPointAfter(computeOpOfReduced); rewriter.setInsertionPointAfter(computeOpOfReduced);
auto receivedValue = auto receivedValue =
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel); spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
outputTiles[outTile][outX][outY] = receivedValue; outputTiles[outTile][outX][outY] = receivedValue;
} }

View File

@@ -63,7 +63,8 @@ struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV1
/*elementType=*/inputTensorType.getElementType()); /*elementType=*/inputTensorType.getElementType());
// Create the ONNXAveragePoolOp. // Create the ONNXAveragePoolOp.
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(), auto averagePool = ONNXAveragePoolOp::create(rewriter,
reduceMean.getLoc(),
resultType, resultType,
inputTensor, inputTensor,
/*auto_pad=*/"NOTSET", /*auto_pad=*/"NOTSET",

View File

@@ -29,7 +29,7 @@ def matMulToGemmPattern : Pat<
(ONNXMatMulOp:$matmulres $A, $B), (ONNXMatMulOp:$matmulres $A, $B),
( (
ONNXGemmOp $A, $B, ONNXGemmOp $A, $B,
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">), /* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">), /* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">), /* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">), /* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),

View File

@@ -47,7 +47,7 @@ SmallVector<Value> sliceTensor(
if (i == numSlices - 1 && lastSliceSize != 0) if (i == numSlices - 1 && lastSliceSize != 0)
sizes[axis] = rewriter.getIndexAttr(lastSliceSize); sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides); Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
slices.push_back(slice); slices.push_back(slice);
} }
@@ -100,11 +100,11 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
int64_t shape[2] = {1, length}; int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType); Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult(); auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero); SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult(); auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
return rewriter.create<tensor::SplatOp>(loc, type, elementValue); return tensor::SplatOp::create(rewriter, loc, type, elementValue);
} }
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) { Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
@@ -122,7 +122,7 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value a = (*currTensors)[i]; Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1]; Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b); rewriter.setInsertionPointAfterValue(b);
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue); nextTensors->push_back(addedValue);
} }
if (currTensors->size() % 2 == 1) if (currTensors->size() % 2 == 1)
@@ -137,10 +137,10 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) { Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
switch (mapOp) { switch (mapOp) {
case MapOperations::None: assert(false && "Invalid map operation during map operation creation."); case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input); case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
} }
} }
@@ -201,7 +201,7 @@ void tileImageTensorByChannel(Value imageTensor,
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides); tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
} }
} }
} }
@@ -225,7 +225,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
for (size_t outTile = 0; outTile < outputTileCount; outTile++) for (size_t outTile = 0; outTile < outputTileCount; outTile++)
tilesToConcat.push_back(outputTiles[outTile][outX][outY]); tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat); return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
} }
LogicalResult LogicalResult
@@ -271,7 +271,7 @@ Value createExtractSliceImg(Value valToSlice,
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides); return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
} }
Value indexImgValue(Value v, Value indexImgValue(Value v,
@@ -384,7 +384,7 @@ void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
offsets[2] = rewriter.getIndexAttr(x); offsets[2] = rewriter.getIndexAttr(x);
offsets[3] = rewriter.getIndexAttr(y); offsets[3] = rewriter.getIndexAttr(y);
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides); inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
} }
} }
} }
@@ -452,7 +452,7 @@ LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)}; SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type()); auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
Value shapeTensor = Value shapeTensor =
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals)); arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType()); auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor); auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);

View File

@@ -9,7 +9,7 @@
#include <fstream> #include <fstream>
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp" #include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"

View File

@@ -272,8 +272,8 @@ void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
// Create a new ComputeOp with the new result type, but same operands // Create a new ComputeOp with the new result type, but same operands
rewriter.setInsertionPoint(oldComputeOp); rewriter.setInsertionPoint(oldComputeOp);
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>( auto newComputeOp = spatial::SpatWeightedCompute::create(
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs()); rewriter, oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
newComputeOp.getBody().takeBody(oldComputeOp.getBody()); newComputeOp.getBody().takeBody(oldComputeOp.getBody());
@@ -333,14 +333,14 @@ OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector<ComputeAndRes
postprocessing = [&](const mlir::Value a) { postprocessing = [&](const mlir::Value a) {
mlir::Value mapOperand = a; mlir::Value mapOperand = a;
if (biasTile) if (biasTile)
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile); mapOperand = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, biasTile);
return createMapOperation(rewriter, mapOp, mapOperand); return createMapOperation(rewriter, mapOp, mapOperand);
}; };
} }
return this->applyReducePattern( return this->applyReducePattern(
computeOps, computeOps,
[&](mlir::Value a, mlir::Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); }, [&](mlir::Value a, mlir::Value b) { return spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b); },
/* preprocess = */ nullptr, /* preprocess = */ nullptr,
postprocessing); postprocessing);
} }

View File

@@ -5,7 +5,7 @@ add_onnx_mlir_library(OMSpatialToGraphviz
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
OMCompilerOptions OMCompilerOptions
OMPIMCommon OMPimCommon
OMONNXOps OMONNXOps
SpatialOps SpatialOps

View File

@@ -10,7 +10,7 @@
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"

View File

@@ -1,20 +0,0 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPIMIncGen)
add_onnx_mlir_library(OMSpatialToPIM
SpatialToPIMPass.cpp
SpatialToPIMCommon.cpp
DEPENDS
SpatialToPIMIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPIMCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -0,0 +1,20 @@
set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen)
add_onnx_mlir_library(OMSpatialToPim
SpatialToPimPass.cpp
SpatialToPimCommon.cpp
DEPENDS
SpatialToPimIncGen
LINK_LIBS PUBLIC
OMCompilerOptions
OMPimCommon
SpatialOps
PimOps
ACCEL_INCLUDE_DIRS PRIVATE
${PIM_INCLUDE_PATH}
)

View File

@@ -5,7 +5,7 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include "SpatialToPIMCommon.hpp" #include "SpatialToPimCommon.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;

View File

@@ -18,8 +18,8 @@
#include <utility> #include <utility>
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp" #include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PimPasses.hpp" #include "src/Accelerators/PIM/Pass/PimPasses.hpp"
@@ -33,15 +33,15 @@ namespace onnx_mlir {
namespace { namespace {
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc" #include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> { struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass) MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
StringRef getArgument() const override { return "convert-spatial-to-pim"; } StringRef getArgument() const override { return "convert-spatial-to-pim"; }
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; } StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
SpatialToPIMPass() = default; SpatialToPimPass() = default;
SpatialToPIMPass(const SpatialToPIMPass& pass) {} SpatialToPimPass(const SpatialToPimPass& pass) {}
void runOnOperation() final; void runOnOperation() final;
@@ -76,7 +76,7 @@ private:
} // namespace } // namespace
void SpatialToPIMPass::runOnOperation() { void SpatialToPimPass::runOnOperation() {
coreId = 1; coreId = 1;
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();
@@ -132,7 +132,7 @@ void SpatialToPIMPass::runOnOperation() {
dumpModule(moduleOp, "pim"); dumpModule(moduleOp, "pim");
} }
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) { void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
Location loc = computeOp->getLoc(); Location loc = computeOp->getLoc();
auto& block = computeOp.getRegion().front(); auto& block = computeOp.getRegion().front();
@@ -180,7 +180,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Store to global memory // Store to global memory
Value outputTensor = outputTensors[resultIndexInReturn]; Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>(loc, PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
yieldValue, yieldValue,
@@ -211,7 +212,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Store to global memory // Store to global memory
Value outputTensor = outputTensors[concatIndexInReturn]; Value outputTensor = outputTensors[concatIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
rewriter.create<PimMemCopyDevToHostOp>( PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
@@ -230,7 +231,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// 1. Create a new ChannelOp // 1. Create a new ChannelOp
rewriter.setInsertionPoint(computeOp); rewriter.setInsertionPoint(computeOp);
auto channelType = spatial::SpatChannelType::get(computeOp.getContext()); auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType); auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
// 2. Receive value through the channel // 2. Receive value through the channel
// If this result is used by more than one user, then use a "Broadcast" // If this result is used by more than one user, then use a "Broadcast"
@@ -244,9 +245,9 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// 3. Send the value through the channel // 3. Send the value through the channel
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
if (useBroadcastOp) if (useBroadcastOp)
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue); spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
else else
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue); spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
} }
// Use `HaltOp` instead of `YieldOp` // Use `HaltOp` instead of `YieldOp`
@@ -255,17 +256,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
// Replace `spat.compute` with `pim.core` // Replace `spat.compute` with `pim.core`
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++)); auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
auto& coreOpBlocks = coreOp.getBody().getBlocks(); auto& coreOpBlocks = coreOp.getBody().getBlocks();
block.eraseArguments(0, block.getNumArguments()); block.eraseArguments(0, block.getNumArguments());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks()); coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block(); Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock); computeOp.getBody().push_back(tempComputeBlock);
rewriter.setInsertionPointToEnd(tempComputeBlock); rewriter.setInsertionPointToEnd(tempComputeBlock);
rewriter.create<PimHaltOp>(computeOp.getLoc()); PimHaltOp::create(rewriter, computeOp.getLoc());
} }
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) { void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void { auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
auto* definingOp = value.getDefiningOp(); auto* definingOp = value.getDefiningOp();
if (!definingOp) if (!definingOp)
@@ -302,14 +303,14 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr}; SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr}; SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
rewriter.setInsertionPointAfter(vmmOp); rewriter.setInsertionPointAfter(vmmOp);
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides); auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp}; SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions); resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
} }
}); });
} }
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands()); outputTensors.reserve(returnOp->getNumOperands());
rewriter.setInsertionPointToStart(returnOp->getBlock()); rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) { for (auto returnValue : returnOp->getOperands()) {
@@ -326,7 +327,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
} }
} }
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
@@ -335,9 +336,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>( auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
rewriter,
loc, loc,
tensorType, tensorType,
deviceTensor, deviceTensor,
@@ -362,7 +364,8 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
Block& block = funcOp.getBody().front(); Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front()); rewriter.setInsertionPoint(&block.front());
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr()); auto toTensorOp =
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp); inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp); tensorArg.replaceAllUsesWith(toTensorOp);
@@ -415,7 +418,7 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
rewriter.eraseOp(sliceOp); rewriter.eraseOp(sliceOp);
} }
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp, void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
unsigned int argIndex, unsigned int argIndex,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& tensorType, Type& tensorType,
@@ -431,14 +434,14 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg)); rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
Value receivedValue; Value receivedValue;
if (useBroadcastOp) if (useBroadcastOp)
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel); receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
else else
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel); receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
blockArg.replaceAllUsesWith(receivedValue); blockArg.replaceAllUsesWith(receivedValue);
} }
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp, void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
spatial::SpatChannelNewOp& channel, spatial::SpatChannelNewOp& channel,
Type& channelTensorType, Type& channelTensorType,
bool& useBroadcastOp, bool& useBroadcastOp,
@@ -495,7 +498,7 @@ void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
} }
} }
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
for (auto it : llvm::enumerate(returnOp.getOperands())) { for (auto it : llvm::enumerate(returnOp.getOperands())) {
Operation* returnOperand = it.value().getDefiningOp(); Operation* returnOperand = it.value().getDefiningOp();
@@ -514,7 +517,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
} }
} }
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) { void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp()); auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
@@ -543,6 +546,6 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
} }
} }
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<SpatialToPIMPass>(); } std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
} // namespace onnx_mlir } // namespace onnx_mlir

View File

@@ -13,7 +13,7 @@ add_onnx_mlir_library(OMPimBufferization
PimBufferizationIncGen PimBufferizationIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
OMPIMCommon OMPimCommon
PimOps PimOps
ACCEL_INCLUDE_DIRS PRIVATE ACCEL_INCLUDE_DIRS PRIVATE

View File

@@ -5,7 +5,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"

View File

@@ -17,7 +17,7 @@
#include <cstdint> #include <cstdint>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
@@ -34,7 +34,7 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType()); auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
// Alloc an output memref // Alloc an output memref
return rewriter.create<memref::AllocOp>(loc, memrefResultType); return memref::AllocOp::create(rewriter, loc, memrefResultType);
} }
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id"); const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
@@ -134,7 +134,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
memrefOperands.push_back(outputTensor); memrefOperands.push_back(outputTensor);
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes(); Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue); replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -169,10 +169,12 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
// Alloc an output memref // Alloc an output memref
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter); Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
Value newValue = Value newValue = ToTy::create(rewriter,
rewriter op->getLoc(),
.create<ToTy>( outputTensor.getType(),
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor) cast<OpTy>(op).getWeightIndexAttr(),
memrefOperand,
outputTensor)
.getOutRes(); .getOutRes();
replaceOpWithBufferizedValues(rewriter, op, newValue); replaceOpWithBufferizedValues(rewriter, op, newValue);
@@ -213,8 +215,8 @@ struct ChannelReceiveOpInterface
if (failed(srcCoreId)) if (failed(srcCoreId))
return failure(); return failure();
Value newValue = rewriter Value newValue = pim::PimReceiveOp::create(rewriter,
.create<pim::PimReceiveOp>(op->getLoc(), op->getLoc(),
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
rewriter.getI32IntegerAttr(numElements * elementSize), rewriter.getI32IntegerAttr(numElements * elementSize),
@@ -324,7 +326,8 @@ struct ChannelBroadcastReceiveOpInterface
} }
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(), auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
op->getLoc(),
outputTensor.getType(), outputTensor.getType(),
outputTensor, outputTensor,
bufferAllocation, bufferAllocation,
@@ -395,7 +398,8 @@ struct ChannelBroadcastSendOpInterface
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8; auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
rewriter.setInsertionPoint(op); rewriter.setInsertionPoint(op);
rewriter.create<pim::PimMemCopyDevToHostOp>(op->getLoc(), pim::PimMemCopyDevToHostOp::create(rewriter,
op->getLoc(),
bufferAllocation.getType(), bufferAllocation.getType(),
bufferAllocation, bufferAllocation,
srcMemRef, srcMemRef,
@@ -481,7 +485,8 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr(); auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr(); auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(), Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
op->getLoc(),
outputTensor.getType(), outputTensor.getType(),
weightIndices, weightIndices,
xKernelPositions, xKernelPositions,

View File

@@ -1,6 +1,6 @@
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "Common/PIMCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -12,7 +12,7 @@
#include <memory> #include <memory>
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;

View File

@@ -5,7 +5,7 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/PIMCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"

View File

@@ -11,7 +11,7 @@ std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass(); std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
std::unique_ptr<mlir::Pass> createSpatialToPIMPass(); std::unique_ptr<mlir::Pass> createSpatialToPimPass();
std::unique_ptr<mlir::Pass> createBufferizePimPass(); std::unique_ptr<mlir::Pass> createBufferizePimPass();

View File

@@ -71,7 +71,7 @@ void PimAccelerator::registerPasses(int optLevel) const {
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n"); LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
registerPass(createONNXToSpatialPass); registerPass(createONNXToSpatialPass);
registerPass(createSpatialToGraphvizPass); registerPass(createSpatialToGraphvizPass);
registerPass(createSpatialToPIMPass); registerPass(createSpatialToPimPass);
registerPass(createBufferizePimPass); registerPass(createBufferizePimPass);
registerPass(createPimFoldHostConstantsPass); registerPass(createPimFoldHostConstantsPass);
registerPass(createPimHostVerificationPass); registerPass(createPimHostVerificationPass);