replace deprecated "rewriter.create()" pattern
refactor PIM to Pim everywhere except for the accelerator name
This commit is contained in:
@@ -43,7 +43,7 @@ add_onnx_mlir_library(OMPIMAccel
|
|||||||
PimOps
|
PimOps
|
||||||
OMONNXToSpatial
|
OMONNXToSpatial
|
||||||
OMSpatialToGraphviz
|
OMSpatialToGraphviz
|
||||||
OMSpatialToPIM
|
OMSpatialToPim
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
MLIRTensorInferTypeOpInterfaceImpl
|
MLIRTensorInferTypeOpInterfaceImpl
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
add_onnx_mlir_library(OMPIMCommon
|
add_onnx_mlir_library(OMPimCommon
|
||||||
PIMCommon.cpp
|
PimCommon.cpp
|
||||||
|
|
||||||
EXCLUDE_FROM_OM_LIBS
|
EXCLUDE_FROM_OM_LIBS
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -37,8 +37,7 @@ FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
|||||||
|
|
||||||
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
||||||
if (entryPoints.size() > 1) {
|
if (entryPoints.size() > 1) {
|
||||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ")
|
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||||
<< entryPoints.size();
|
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
if (!entryPoints.empty()) {
|
if (!entryPoints.empty()) {
|
||||||
@@ -61,10 +60,9 @@ FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
|||||||
return mainGraphFunc;
|
return mainGraphFunc;
|
||||||
|
|
||||||
SmallVector<func::FuncOp> nonExternalFuncs;
|
SmallVector<func::FuncOp> nonExternalFuncs;
|
||||||
for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
|
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
||||||
if (!funcOp.isExternal())
|
if (!funcOp.isExternal())
|
||||||
nonExternalFuncs.push_back(funcOp);
|
nonExternalFuncs.push_back(funcOp);
|
||||||
}
|
|
||||||
if (nonExternalFuncs.size() == 1)
|
if (nonExternalFuncs.size() == 1)
|
||||||
return nonExternalFuncs.front();
|
return nonExternalFuncs.front();
|
||||||
|
|
||||||
@@ -72,11 +70,11 @@ FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
|||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME) != nullptr; }
|
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||||
|
|
||||||
void markWeightAlways(Operation* op) {
|
void markWeightAlways(Operation* op) {
|
||||||
assert(op && "expected valid op");
|
assert(op && "expected valid op");
|
||||||
op->setAttr(PIM_WEIGHT_ALWAYS_ATTR_NAME, UnitAttr::get(op->getContext()));
|
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||||
}
|
}
|
||||||
|
|
||||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/IR/Operation.h"
|
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/Operation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/IR/Value.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
@@ -10,8 +10,8 @@
|
|||||||
|
|
||||||
#include "src/Compiler/CompilerOptions.hpp"
|
#include "src/Compiler/CompilerOptions.hpp"
|
||||||
|
|
||||||
const llvm::StringRef PIM_CONSTANT_SHOULD_ALLOCATE_ATTR_NAME = "pim.constant.should_allocate";
|
const llvm::StringRef PimConstantShouldAllocateAttrName = "pim.constant.should_allocate";
|
||||||
inline constexpr llvm::StringRef PIM_WEIGHT_ALWAYS_ATTR_NAME = "weightAlways";
|
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
@@ -13,9 +13,9 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||||
#include "Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (pimEmissionTarget >= EmitPim) {
|
if (pimEmissionTarget >= EmitPim) {
|
||||||
pm.addPass(createSpatialToPIMPass());
|
pm.addPass(createSpatialToPimPass());
|
||||||
// pm.addPass(createCountInstructionPass());
|
// pm.addPass(createCountInstructionPass());
|
||||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
add_subdirectory(ONNXToSpatial)
|
add_subdirectory(ONNXToSpatial)
|
||||||
add_subdirectory(SpatialToGraphviz)
|
add_subdirectory(SpatialToGraphviz)
|
||||||
add_subdirectory(SpatialToPIM)
|
add_subdirectory(SpatialToPim)
|
||||||
@@ -23,7 +23,7 @@ add_onnx_mlir_library(OMONNXToSpatial
|
|||||||
OMPimCompilerOptions
|
OMPimCompilerOptions
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
${PIM_INCLUDE_PATH}
|
${PIM_INCLUDE_PATH}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
@@ -91,7 +91,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
auto aSliceType = RankedTensorType::get({1, aType.getDimSize(1)}, aType.getElementType());
|
||||||
auto aSlice = rewriter.create<tensor::ExtractSliceOp>(loc, aSliceType, a, offsets, sizes, strides).getResult();
|
auto aSlice = tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, strides).getResult();
|
||||||
|
|
||||||
Value cSlice = c;
|
Value cSlice = c;
|
||||||
if (hasC) {
|
if (hasC) {
|
||||||
@@ -100,13 +100,14 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(cType.getDimSize(1))};
|
||||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||||
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
auto cSliceType = RankedTensorType::get({1, cType.getDimSize(1)}, cType.getElementType());
|
||||||
cSlice = rewriter.create<tensor::ExtractSliceOp>(loc, cSliceType, c, offsets, sizes, strides).getResult();
|
cSlice = tensor::ExtractSliceOp::create(rewriter, loc, cSliceType, c, offsets, sizes, strides).getResult();
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
assert("C should be a vector" && isVectorShape(getTensorShape(c)));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto gemvOp = rewriter.create<ONNXGemmOp>(loc,
|
auto gemvOp = ONNXGemmOp::create(rewriter,
|
||||||
|
loc,
|
||||||
outRowType,
|
outRowType,
|
||||||
aSlice,
|
aSlice,
|
||||||
b,
|
b,
|
||||||
@@ -119,7 +120,7 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp =
|
||||||
rewriter.create<spatial::SpatWeightedCompute>(loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
spatial::SpatWeightedCompute::create(rewriter, loc, gemmOp.getType(), SmallVector<Value>(), gemvOps);
|
||||||
|
|
||||||
auto* concatBlock = new Block();
|
auto* concatBlock = new Block();
|
||||||
for (auto gemvOp : gemvOps)
|
for (auto gemvOp : gemvOps)
|
||||||
@@ -128,8 +129,8 @@ LogicalResult GemmToManyGemv::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
rewriter.setInsertionPointToStart(concatBlock);
|
rewriter.setInsertionPointToStart(concatBlock);
|
||||||
|
|
||||||
auto blockArgs = concatBlock->getArguments();
|
auto blockArgs = concatBlock->getArguments();
|
||||||
auto concatOp = rewriter.create<tensor::ConcatOp>(loc, /*axis=*/0, blockArgs);
|
auto concatOp = tensor::ConcatOp::create(rewriter, loc, /*axis=*/0, blockArgs);
|
||||||
rewriter.create<spatial::SpatYieldOp>(loc, concatOp.getResult());
|
spatial::SpatYieldOp::create(rewriter, loc, concatOp.getResult());
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
@@ -170,25 +171,25 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
if (transA) {
|
if (transA) {
|
||||||
auto aShape = aType.getShape();
|
auto aShape = aType.getShape();
|
||||||
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
auto transposedType = aType.cloneWith(ArrayRef({aShape[1], aShape[0]}), aType.getElementType());
|
||||||
a = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
a = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, a, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
}
|
}
|
||||||
if (transB) {
|
if (transB) {
|
||||||
auto bShape = bType.getShape();
|
auto bShape = bType.getShape();
|
||||||
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
auto transposedType = bType.cloneWith(ArrayRef({bShape[1], bShape[0]}), bType.getElementType());
|
||||||
b = rewriter.create<ONNXTransposeOp>(gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
b = ONNXTransposeOp::create(rewriter, gemmLoc, transposedType, b, rewriter.getI64ArrayAttr({1, 0}));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (alpha != 1.0f) {
|
if (alpha != 1.0f) {
|
||||||
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
auto alphaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(a.getType()).getElementType());
|
||||||
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
auto alphaTensorValue = DenseFPElementsAttr::get(alphaTensorType, {alpha});
|
||||||
auto alphaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, alphaTensorType, alphaTensorValue);
|
auto alphaTensor = arith::ConstantOp::create(rewriter, gemmLoc, alphaTensorType, alphaTensorValue);
|
||||||
a = rewriter.create<spatial::SpatVMulOp>(gemmLoc, a.getType(), a, alphaTensor);
|
a = spatial::SpatVMulOp::create(rewriter, gemmLoc, a.getType(), a, alphaTensor);
|
||||||
}
|
}
|
||||||
if (hasC && beta != 1.0f) {
|
if (hasC && beta != 1.0f) {
|
||||||
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
auto betaTensorType = RankedTensorType::get({1, 1}, cast<RankedTensorType>(c.getType()).getElementType());
|
||||||
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
auto betaTensorValue = DenseFPElementsAttr::get(betaTensorType, {beta});
|
||||||
auto betaTensor = rewriter.create<arith::ConstantOp>(gemmLoc, betaTensorType, betaTensorValue);
|
auto betaTensor = arith::ConstantOp::create(rewriter, gemmLoc, betaTensorType, betaTensorValue);
|
||||||
c = rewriter.create<spatial::SpatVMulOp>(gemmLoc, c.getType(), c, betaTensor);
|
c = spatial::SpatVMulOp::create(rewriter, gemmLoc, c.getType(), c, betaTensor);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
auto [aNumHSlices, aLastHSliceSize] = ceilIntegerDivideWithRemainder(aType.getDimSize(1), crossbarSize.getValue());
|
||||||
@@ -235,7 +236,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
weights.push_back(bTiles[outSliceId][coreId][aSliceId]);
|
||||||
|
|
||||||
auto computeOp =
|
auto computeOp =
|
||||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, weights, aHSlices[coreId]);
|
||||||
|
|
||||||
auto* computeBlock = new Block();
|
auto* computeBlock = new Block();
|
||||||
for (auto aHSlice : aHSlices[coreId])
|
for (auto aHSlice : aHSlices[coreId])
|
||||||
@@ -248,11 +249,11 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
vmmOutputs.reserve(computeArgs.size());
|
vmmOutputs.reserve(computeArgs.size());
|
||||||
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
for (size_t aHSliceId = 0; aHSliceId < aNumHSlices; aHSliceId++)
|
||||||
vmmOutputs.push_back(
|
vmmOutputs.push_back(
|
||||||
rewriter.create<spatial::SpatWeightedVMMOp>(gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
spatial::SpatWeightedVMMOp::create(rewriter, gemmLoc, currOutHSliceType, aHSliceId, computeArgs[aHSliceId]));
|
||||||
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
assert(!vmmOutputs.empty() && "vmmOutputs must be non-empty");
|
||||||
|
|
||||||
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
Value partialVmmSum = sumTensors(vmmOutputs, rewriter);
|
||||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, partialVmmSum);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, partialVmmSum);
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
|
|
||||||
partialResults.push_back(computeOp.getResult(0));
|
partialResults.push_back(computeOp.getResult(0));
|
||||||
@@ -264,7 +265,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto reduceComputeOp =
|
auto reduceComputeOp =
|
||||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, currOutHSliceType, SmallVector<Value>(), partialResults);
|
||||||
|
|
||||||
auto* reduceBlock = new Block();
|
auto* reduceBlock = new Block();
|
||||||
for (auto partialResult : partialResults)
|
for (auto partialResult : partialResults)
|
||||||
@@ -274,14 +275,14 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
|
|
||||||
auto blockArgs = reduceBlock->getArguments();
|
auto blockArgs = reduceBlock->getArguments();
|
||||||
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
Value outHSlice = sumTensors({blockArgs.begin(), blockArgs.end()}, rewriter);
|
||||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, outHSlice);
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, outHSlice);
|
||||||
rewriter.setInsertionPointAfter(reduceComputeOp);
|
rewriter.setInsertionPointAfter(reduceComputeOp);
|
||||||
|
|
||||||
outHSlices.push_back(reduceComputeOp.getResult(0));
|
outHSlices.push_back(reduceComputeOp.getResult(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto concatComputeOp =
|
auto concatComputeOp =
|
||||||
rewriter.create<spatial::SpatWeightedCompute>(gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
spatial::SpatWeightedCompute::create(rewriter, gemmLoc, gemmOp.getType(), SmallVector<Value>(), outHSlices);
|
||||||
|
|
||||||
auto* concatBlock = new Block();
|
auto* concatBlock = new Block();
|
||||||
for (auto outHSlice : outHSlices)
|
for (auto outHSlice : outHSlices)
|
||||||
@@ -290,8 +291,8 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
|||||||
rewriter.setInsertionPointToStart(concatBlock);
|
rewriter.setInsertionPointToStart(concatBlock);
|
||||||
|
|
||||||
auto blockArgs = concatBlock->getArguments();
|
auto blockArgs = concatBlock->getArguments();
|
||||||
auto concatOp = rewriter.create<tensor::ConcatOp>(gemmLoc, /*axis=*/1, blockArgs);
|
auto concatOp = tensor::ConcatOp::create(rewriter, gemmLoc, /*axis=*/1, blockArgs);
|
||||||
rewriter.create<spatial::SpatYieldOp>(gemmLoc, concatOp.getResult());
|
spatial::SpatYieldOp::create(rewriter, gemmLoc, concatOp.getResult());
|
||||||
|
|
||||||
rewriter.replaceOp(gemmOp, concatComputeOp);
|
rewriter.replaceOp(gemmOp, concatComputeOp);
|
||||||
return success();
|
return success();
|
||||||
@@ -335,9 +336,9 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
|
|||||||
|
|
||||||
reducer.applyReducePattern(
|
reducer.applyReducePattern(
|
||||||
softmaxOpsToReduce,
|
softmaxOpsToReduce,
|
||||||
[&](Value a, Value b) { return rewriter.create<spatial::SpatVAddOp>(loc, scalarTensorType, a, b); },
|
[&](Value a, Value b) { return spatial::SpatVAddOp::create(rewriter, loc, scalarTensorType, a, b); },
|
||||||
/* preprocess = */
|
/* preprocess = */
|
||||||
[&](Value a) { return rewriter.create<spatial::SpatSumOp>(loc, scalarTensorType, a); },
|
[&](Value a) { return spatial::SpatSumOp::create(rewriter, loc, scalarTensorType, a); },
|
||||||
[&](Value softmaxDivisor) {
|
[&](Value softmaxDivisor) {
|
||||||
// Signal that this is the compute with the softmax divisor
|
// Signal that this is the compute with the softmax divisor
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
|
auto computeOp = cast<spatial::SpatWeightedCompute>(softmaxDivisor.getDefiningOp()->getParentOp());
|
||||||
@@ -345,7 +346,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
|
|||||||
|
|
||||||
// Broadcast the divisor to all the cores
|
// Broadcast the divisor to all the cores
|
||||||
rewriter.setInsertionPointAfterValue(softmaxDivisor);
|
rewriter.setInsertionPointAfterValue(softmaxDivisor);
|
||||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, softmaxChannel, softmaxDivisor);
|
spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, softmaxChannel, softmaxDivisor);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* softmaxDividend = onnx.exp (...)
|
* softmaxDividend = onnx.exp (...)
|
||||||
@@ -395,7 +396,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
|
|||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
rewriter.setInsertionPoint(yieldOp);
|
rewriter.setInsertionPoint(yieldOp);
|
||||||
divisor = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(loc, scalarTensorType, softmaxChannel);
|
divisor = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, loc, scalarTensorType, softmaxChannel);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Walk the chain of operations until we find the ONNXExpOp: this is
|
// Walk the chain of operations until we find the ONNXExpOp: this is
|
||||||
@@ -405,7 +406,7 @@ LogicalResult GemvToSpatialCompute::softmaxReductionApplication(SmallVector<OpAn
|
|||||||
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
|
Value oldOutputTile = resolveONNXExpOpFromUseChain(yieldOp->getOperand(computeToDivideOpAndResNum.second));
|
||||||
|
|
||||||
rewriter.setInsertionPoint(yieldOp);
|
rewriter.setInsertionPoint(yieldOp);
|
||||||
Value newOutputTile = rewriter.create<spatial::SpatVSDivOp>(loc, oldOutputTile.getType(), oldOutputTile, divisor);
|
Value newOutputTile = spatial::SpatVSDivOp::create(rewriter, loc, oldOutputTile.getType(), oldOutputTile, divisor);
|
||||||
auto yieldOperandNum = yieldOp->getNumOperands();
|
auto yieldOperandNum = yieldOp->getNumOperands();
|
||||||
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
|
yieldOp->insertOperands(yieldOperandNum, newOutputTile);
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Utils/SpatialReducer.hpp"
|
||||||
@@ -111,15 +111,15 @@ Value applyReducePatternNew(SmallVector<Value>& valuesToReduce,
|
|||||||
|
|
||||||
// 1. Add a channel before the first computeOp
|
// 1. Add a channel before the first computeOp
|
||||||
rewriter.setInsertionPoint(firstCompute);
|
rewriter.setInsertionPoint(firstCompute);
|
||||||
auto channel = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
|
auto channel = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||||
|
|
||||||
// 2. Add a sendOp after the first value
|
// 2. Add a sendOp after the first value
|
||||||
rewriter.setInsertionPointAfterValue(firstValue);
|
rewriter.setInsertionPointAfterValue(firstValue);
|
||||||
rewriter.create<spatial::SpatChannelSendOp>(loc, channel, firstValue);
|
spatial::SpatChannelSendOp::create(rewriter, loc, channel, firstValue);
|
||||||
|
|
||||||
// 3. Add a receiveOp after the second value
|
// 3. Add a receiveOp after the second value
|
||||||
rewriter.setInsertionPointAfterValue(secondValue);
|
rewriter.setInsertionPointAfterValue(secondValue);
|
||||||
auto receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(loc, secondValue.getType(), channel);
|
auto receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, loc, secondValue.getType(), channel);
|
||||||
|
|
||||||
// 4. Apply reduction between second value and received value
|
// 4. Apply reduction between second value and received value
|
||||||
rewriter.setInsertionPointAfterValue(receivedValue);
|
rewriter.setInsertionPointAfterValue(receivedValue);
|
||||||
@@ -188,13 +188,14 @@ Value postProcessPoolingWindow<ONNXAveragePoolOp>(ConversionPatternRewriter& rew
|
|||||||
// directly under func.func (i.e. alongside ComputeOps)
|
// directly under func.func (i.e. alongside ComputeOps)
|
||||||
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
auto computeOp = cast<spatial::SpatWeightedCompute>(valueToDivide.getDefiningOp()->getParentOp());
|
||||||
rewriter.setInsertionPoint(computeOp);
|
rewriter.setInsertionPoint(computeOp);
|
||||||
auto divisorValue = rewriter.create<spatial::SpatConstantOp>(loc,
|
auto divisorValue = spatial::SpatConstantOp::create(rewriter,
|
||||||
|
loc,
|
||||||
scalarTensor,
|
scalarTensor,
|
||||||
rewriter.getI64IntegerAttr(divisorNumber),
|
rewriter.getI64IntegerAttr(divisorNumber),
|
||||||
/* should_allocate = */ rewriter.getBoolAttr(true));
|
/* should_allocate = */ rewriter.getBoolAttr(true));
|
||||||
|
|
||||||
rewriter.setInsertionPointAfterValue(valueToDivide);
|
rewriter.setInsertionPointAfterValue(valueToDivide);
|
||||||
return rewriter.create<spatial::SpatVSDivOp>(loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
return spatial::SpatVSDivOp::create(rewriter, loc, valueToDivide.getType(), valueToDivide, divisorValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
template <typename PoolOp, typename PoolOpAdaptor, typename ReduceOp>
|
||||||
@@ -257,7 +258,8 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|||||||
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
if (auto extractSliceOp = inputTiles[t][x][y].getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||||
Location tileLoc = extractSliceOp.getLoc();
|
Location tileLoc = extractSliceOp.getLoc();
|
||||||
|
|
||||||
auto tempComputeOp = rewriter.create<spatial::SpatWeightedCompute>(tileLoc,
|
auto tempComputeOp = spatial::SpatWeightedCompute::create(rewriter,
|
||||||
|
tileLoc,
|
||||||
extractSliceOp.getResultType(),
|
extractSliceOp.getResultType(),
|
||||||
/* xbarWeights =*/ValueRange(),
|
/* xbarWeights =*/ValueRange(),
|
||||||
extractSliceOp.getResult());
|
extractSliceOp.getResult());
|
||||||
@@ -267,7 +269,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|||||||
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
auto tempComputeOpBlockArg = tempComputeOpBlock->addArgument(extractSliceOp.getType(), tileLoc);
|
||||||
|
|
||||||
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
rewriter.setInsertionPointToStart(tempComputeOpBlock);
|
||||||
rewriter.create<spatial::SpatYieldOp>(tileLoc, tempComputeOpBlockArg);
|
spatial::SpatYieldOp::create(rewriter, tileLoc, tempComputeOpBlockArg);
|
||||||
rewriter.setInsertionPointAfter(tempComputeOp);
|
rewriter.setInsertionPointAfter(tempComputeOp);
|
||||||
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
inputTiles[t][x][y] = tempComputeOp.getResult(0);
|
||||||
}
|
}
|
||||||
@@ -356,7 +358,7 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|||||||
Value reducedWithinCompute = applyReducePatternNew(
|
Value reducedWithinCompute = applyReducePatternNew(
|
||||||
valuesToPool,
|
valuesToPool,
|
||||||
rewriter,
|
rewriter,
|
||||||
[&](const Value lhs, const Value rhs) { return rewriter.create<ReduceOp>(loc, lhs.getType(), lhs, rhs); },
|
[&](const Value lhs, const Value rhs) { return ReduceOp::create(rewriter, loc, lhs.getType(), lhs, rhs); },
|
||||||
nullptr,
|
nullptr,
|
||||||
postProcessFn);
|
postProcessFn);
|
||||||
|
|
||||||
@@ -369,16 +371,16 @@ struct PoolingBaseConverter : public OpConversionPattern<PoolOp> {
|
|||||||
// Create a new channel before the computeOp
|
// Create a new channel before the computeOp
|
||||||
rewriter.setInsertionPoint(computeOpOfReduced);
|
rewriter.setInsertionPoint(computeOpOfReduced);
|
||||||
auto reduceChannel =
|
auto reduceChannel =
|
||||||
rewriter.create<spatial::SpatChannelNewOp>(loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(rewriter.getContext()));
|
||||||
|
|
||||||
// Send value through the channel
|
// Send value through the channel
|
||||||
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
rewriter.setInsertionPointAfterValue(reducedWithinCompute);
|
||||||
rewriter.create<spatial::SpatChannelSendOp>(loc, reduceChannel, reducedWithinCompute);
|
spatial::SpatChannelSendOp::create(rewriter, loc, reduceChannel, reducedWithinCompute);
|
||||||
|
|
||||||
// Receive after the computeOp
|
// Receive after the computeOp
|
||||||
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
rewriter.setInsertionPointAfter(computeOpOfReduced);
|
||||||
auto receivedValue =
|
auto receivedValue =
|
||||||
rewriter.create<spatial::SpatChannelReceiveOp>(loc, reducedWithinCompute.getType(), reduceChannel);
|
spatial::SpatChannelReceiveOp::create(rewriter, loc, reducedWithinCompute.getType(), reduceChannel);
|
||||||
|
|
||||||
outputTiles[outTile][outX][outY] = receivedValue;
|
outputTiles[outTile][outX][outY] = receivedValue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,8 @@ struct ReduceMeanConversionPattern : public OpConversionPattern<ONNXReduceMeanV1
|
|||||||
/*elementType=*/inputTensorType.getElementType());
|
/*elementType=*/inputTensorType.getElementType());
|
||||||
|
|
||||||
// Create the ONNXAveragePoolOp.
|
// Create the ONNXAveragePoolOp.
|
||||||
auto averagePool = rewriter.create<ONNXAveragePoolOp>(reduceMean.getLoc(),
|
auto averagePool = ONNXAveragePoolOp::create(rewriter,
|
||||||
|
reduceMean.getLoc(),
|
||||||
resultType,
|
resultType,
|
||||||
inputTensor,
|
inputTensor,
|
||||||
/*auto_pad=*/"NOTSET",
|
/*auto_pad=*/"NOTSET",
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ def matMulToGemmPattern : Pat<
|
|||||||
(ONNXMatMulOp:$matmulres $A, $B),
|
(ONNXMatMulOp:$matmulres $A, $B),
|
||||||
(
|
(
|
||||||
ONNXGemmOp $A, $B,
|
ONNXGemmOp $A, $B,
|
||||||
/* C = */ (NativeCodeCall<"$_builder.create<tensor::EmptyOp>($_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
/* C = */ (NativeCodeCall<"tensor::EmptyOp::create($_builder, $_loc, cast<ShapedType>(matmulres.getY().getType()).getShape(), cast<ShapedType>(matmulres.getY().getType()).getElementType());">),
|
||||||
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
/* alpha = */ (NativeCodeCall<"$_builder.getF32FloatAttr(1)">),
|
||||||
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
/* beta = */ (NativeCodeCall<"$_builder.getF32FloatAttr(0)">),
|
||||||
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
/* transA = */ (NativeCodeCall<"IntegerAttr::get($_builder.getIntegerType(64, true), 0)">),
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ SmallVector<Value> sliceTensor(
|
|||||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||||
|
|
||||||
Value slice = rewriter.create<tensor::ExtractSliceOp>(loc, tensorToSlice, offsets, sizes, strides);
|
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||||
slices.push_back(slice);
|
slices.push_back(slice);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,11 +100,11 @@ broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewr
|
|||||||
int64_t shape[2] = {1, length};
|
int64_t shape[2] = {1, length};
|
||||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||||
|
|
||||||
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
|
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||||
SmallVector<Value> index(oldType.getRank(), zero);
|
SmallVector<Value> index(oldType.getRank(), zero);
|
||||||
auto elementValue = rewriter.create<tensor::ExtractOp>(loc, scalarToBroadcast, index).getResult();
|
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||||
|
|
||||||
return rewriter.create<tensor::SplatOp>(loc, type, elementValue);
|
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||||
@@ -122,7 +122,7 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
|||||||
Value a = (*currTensors)[i];
|
Value a = (*currTensors)[i];
|
||||||
Value b = (*currTensors)[i + 1];
|
Value b = (*currTensors)[i + 1];
|
||||||
rewriter.setInsertionPointAfterValue(b);
|
rewriter.setInsertionPointAfterValue(b);
|
||||||
auto addedValue = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b);
|
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||||
nextTensors->push_back(addedValue);
|
nextTensors->push_back(addedValue);
|
||||||
}
|
}
|
||||||
if (currTensors->size() % 2 == 1)
|
if (currTensors->size() % 2 == 1)
|
||||||
@@ -137,10 +137,10 @@ Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
|||||||
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
|
Value createMapOperation(PatternRewriter& rewriter, MapOperations mapOp, const Value& input) {
|
||||||
switch (mapOp) {
|
switch (mapOp) {
|
||||||
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
|
case MapOperations::None: assert(false && "Invalid map operation during map operation creation.");
|
||||||
case MapOperations::ONNXSoftmaxOp: return rewriter.create<ONNXSoftmaxOp>(input.getLoc(), input.getType(), input);
|
case MapOperations::ONNXSoftmaxOp: return ONNXSoftmaxOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||||
case MapOperations::ONNXReluOp: return rewriter.create<ONNXReluOp>(input.getLoc(), input.getType(), input);
|
case MapOperations::ONNXReluOp: return ONNXReluOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||||
case MapOperations::ONNXLeakyReluOp: return rewriter.create<ONNXLeakyReluOp>(input.getLoc(), input.getType(), input);
|
case MapOperations::ONNXLeakyReluOp: return ONNXLeakyReluOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||||
case MapOperations::ONNXExpOp: return rewriter.create<ONNXExpOp>(input.getLoc(), input.getType(), input);
|
case MapOperations::ONNXExpOp: return ONNXExpOp::create(rewriter, input.getLoc(), input.getType(), input);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,7 +201,7 @@ void tileImageTensorByChannel(Value imageTensor,
|
|||||||
offsets[2] = rewriter.getIndexAttr(x);
|
offsets[2] = rewriter.getIndexAttr(x);
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
offsets[3] = rewriter.getIndexAttr(y);
|
||||||
|
|
||||||
tiles[i][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, imageTensor, offsets, sizes, strides);
|
tiles[i][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, imageTensor, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -225,7 +225,7 @@ Value createImgConcatOp(SmallVector<SmallVector<SmallVector<Value>>>& outputTile
|
|||||||
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
for (size_t outTile = 0; outTile < outputTileCount; outTile++)
|
||||||
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
|
tilesToConcat.push_back(outputTiles[outTile][outX][outY]);
|
||||||
|
|
||||||
return rewriter.create<spatial::SpatImgConcatOp>(loc, outputType, tilesToConcat);
|
return spatial::SpatImgConcatOp::create(rewriter, loc, outputType, tilesToConcat);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult
|
LogicalResult
|
||||||
@@ -271,7 +271,7 @@ Value createExtractSliceImg(Value valToSlice,
|
|||||||
offsets[2] = rewriter.getIndexAttr(x);
|
offsets[2] = rewriter.getIndexAttr(x);
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
offsets[3] = rewriter.getIndexAttr(y);
|
||||||
|
|
||||||
return rewriter.create<tensor::ExtractSliceOp>(valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
|
return tensor::ExtractSliceOp::create(rewriter, valToSlice.getLoc(), valToSlice, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
Value indexImgValue(Value v,
|
Value indexImgValue(Value v,
|
||||||
@@ -384,7 +384,7 @@ void resolveInputTensorTilesBlockArg(Value wholeInputTensor,
|
|||||||
offsets[2] = rewriter.getIndexAttr(x);
|
offsets[2] = rewriter.getIndexAttr(x);
|
||||||
offsets[3] = rewriter.getIndexAttr(y);
|
offsets[3] = rewriter.getIndexAttr(y);
|
||||||
|
|
||||||
inputTiles[t][x][y] = rewriter.create<tensor::ExtractSliceOp>(loc, wholeInputTensor, offsets, sizes, strides);
|
inputTiles[t][x][y] = tensor::ExtractSliceOp::create(rewriter, loc, wholeInputTensor, offsets, sizes, strides);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -452,7 +452,7 @@ LogicalResult handleFlattenLikeOp(SmallVector<SmallVector<Value>>& inputTiles,
|
|||||||
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
|
SmallVector<int64_t> newShapeVals = {curTileShape.getDimSize(0), curTileShape.getDimSize(1)};
|
||||||
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
|
auto shapeType = RankedTensorType::get({static_cast<int64_t>(newShapeVals.size())}, rewriter.getI64Type());
|
||||||
Value shapeTensor =
|
Value shapeTensor =
|
||||||
rewriter.create<arith::ConstantOp>(reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
|
arith::ConstantOp::create(rewriter, reshapeInput.getLoc(), DenseIntElementsAttr::get(shapeType, newShapeVals));
|
||||||
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
|
auto reshapedType = RankedTensorType::get(newShapeVals, curTileShape.getElementType());
|
||||||
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
|
auto reshapedCurTile = tosa::ReshapeOp::create(rewriter, reshapeInput.getLoc(), reshapedType, curTile, shapeTensor);
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
#include "Conversion/ONNXToSpatial/Utils/AnnotateReplication.hpp"
|
||||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPatterns.hpp"
|
||||||
|
|||||||
@@ -272,8 +272,8 @@ void SpatialReducer::updateResultsOfCompute(mlir::Operation* computeOp) {
|
|||||||
|
|
||||||
// Create a new ComputeOp with the new result type, but same operands
|
// Create a new ComputeOp with the new result type, but same operands
|
||||||
rewriter.setInsertionPoint(oldComputeOp);
|
rewriter.setInsertionPoint(oldComputeOp);
|
||||||
auto newComputeOp = rewriter.create<spatial::SpatWeightedCompute>(
|
auto newComputeOp = spatial::SpatWeightedCompute::create(
|
||||||
oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
rewriter, oldComputeOp->getLoc(), newResultTypes, oldComputeOp.getWeights(), oldComputeOp.getInputs());
|
||||||
|
|
||||||
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
|
newComputeOp.getBody().takeBody(oldComputeOp.getBody());
|
||||||
|
|
||||||
@@ -333,14 +333,14 @@ OpAndResNum SpatialReducer::applyAddMapReduction(llvm::SmallVector<ComputeAndRes
|
|||||||
postprocessing = [&](const mlir::Value a) {
|
postprocessing = [&](const mlir::Value a) {
|
||||||
mlir::Value mapOperand = a;
|
mlir::Value mapOperand = a;
|
||||||
if (biasTile)
|
if (biasTile)
|
||||||
mapOperand = rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, biasTile);
|
mapOperand = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, biasTile);
|
||||||
return createMapOperation(rewriter, mapOp, mapOperand);
|
return createMapOperation(rewriter, mapOp, mapOperand);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
return this->applyReducePattern(
|
return this->applyReducePattern(
|
||||||
computeOps,
|
computeOps,
|
||||||
[&](mlir::Value a, mlir::Value b) { return rewriter.create<spatial::SpatVAddOp>(a.getLoc(), a.getType(), a, b); },
|
[&](mlir::Value a, mlir::Value b) { return spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b); },
|
||||||
/* preprocess = */ nullptr,
|
/* preprocess = */ nullptr,
|
||||||
postprocessing);
|
postprocessing);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ add_onnx_mlir_library(OMSpatialToGraphviz
|
|||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
OMCompilerOptions
|
OMCompilerOptions
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
OMONNXOps
|
OMONNXOps
|
||||||
SpatialOps
|
SpatialOps
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
#include "llvm/Support/Casting.h"
|
#include "llvm/Support/Casting.h"
|
||||||
#include "llvm/Support/Format.h"
|
#include "llvm/Support/Format.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
set(LLVM_TARGET_DEFINITIONS SpatialToPIM.td)
|
|
||||||
mlir_tablegen(SpatialToPIM.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
|
||||||
add_public_tablegen_target(SpatialToPIMIncGen)
|
|
||||||
|
|
||||||
add_onnx_mlir_library(OMSpatialToPIM
|
|
||||||
SpatialToPIMPass.cpp
|
|
||||||
SpatialToPIMCommon.cpp
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
SpatialToPIMIncGen
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
OMCompilerOptions
|
|
||||||
OMPIMCommon
|
|
||||||
SpatialOps
|
|
||||||
PimOps
|
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
|
||||||
${PIM_INCLUDE_PATH}
|
|
||||||
)
|
|
||||||
20
src/PIM/Conversion/SpatialToPim/CMakeLists.txt
Normal file
20
src/PIM/Conversion/SpatialToPim/CMakeLists.txt
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
set(LLVM_TARGET_DEFINITIONS SpatialToPim.td)
|
||||||
|
mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||||
|
add_public_tablegen_target(SpatialToPimIncGen)
|
||||||
|
|
||||||
|
add_onnx_mlir_library(OMSpatialToPim
|
||||||
|
SpatialToPimPass.cpp
|
||||||
|
SpatialToPimCommon.cpp
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
SpatialToPimIncGen
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
OMCompilerOptions
|
||||||
|
OMPimCommon
|
||||||
|
SpatialOps
|
||||||
|
PimOps
|
||||||
|
|
||||||
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
|
${PIM_INCLUDE_PATH}
|
||||||
|
)
|
||||||
@@ -5,7 +5,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
|
|
||||||
#include "SpatialToPIMCommon.hpp"
|
#include "SpatialToPimCommon.hpp"
|
||||||
|
|
||||||
using namespace llvm;
|
using namespace llvm;
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -18,8 +18,8 @@
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
#include "Conversion/ONNXToSpatial/ONNXToSpatialCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIMCommon.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
#include "src/Accelerators/PIM/Pass/PimPasses.hpp"
|
||||||
@@ -33,15 +33,15 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPIM/SpatialToPIM.hpp.inc"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||||
|
|
||||||
struct SpatialToPIMPass : PassWrapper<SpatialToPIMPass, OperationPass<ModuleOp>> {
|
struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>> {
|
||||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPIMPass)
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SpatialToPimPass)
|
||||||
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
StringRef getArgument() const override { return "convert-spatial-to-pim"; }
|
||||||
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
StringRef getDescription() const override { return "Lower Spatial ops to PIM-ready format"; }
|
||||||
|
|
||||||
SpatialToPIMPass() = default;
|
SpatialToPimPass() = default;
|
||||||
SpatialToPIMPass(const SpatialToPIMPass& pass) {}
|
SpatialToPimPass(const SpatialToPimPass& pass) {}
|
||||||
|
|
||||||
void runOnOperation() final;
|
void runOnOperation() final;
|
||||||
|
|
||||||
@@ -76,7 +76,7 @@ private:
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnOperation() {
|
void SpatialToPimPass::runOnOperation() {
|
||||||
coreId = 1;
|
coreId = 1;
|
||||||
ModuleOp moduleOp = getOperation();
|
ModuleOp moduleOp = getOperation();
|
||||||
MLIRContext* ctx = moduleOp.getContext();
|
MLIRContext* ctx = moduleOp.getContext();
|
||||||
@@ -132,7 +132,7 @@ void SpatialToPIMPass::runOnOperation() {
|
|||||||
dumpModule(moduleOp, "pim");
|
dumpModule(moduleOp, "pim");
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IRRewriter& rewriter) {
|
||||||
Location loc = computeOp->getLoc();
|
Location loc = computeOp->getLoc();
|
||||||
|
|
||||||
auto& block = computeOp.getRegion().front();
|
auto& block = computeOp.getRegion().front();
|
||||||
@@ -180,7 +180,8 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// Store to global memory
|
// Store to global memory
|
||||||
Value outputTensor = outputTensors[resultIndexInReturn];
|
Value outputTensor = outputTensors[resultIndexInReturn];
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
rewriter.create<PimMemCopyDevToHostOp>(loc,
|
PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
yieldValue,
|
yieldValue,
|
||||||
@@ -211,7 +212,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// Store to global memory
|
// Store to global memory
|
||||||
Value outputTensor = outputTensors[concatIndexInReturn];
|
Value outputTensor = outputTensors[concatIndexInReturn];
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
rewriter.create<PimMemCopyDevToHostOp>(
|
PimMemCopyDevToHostOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
@@ -230,7 +231,7 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// 1. Create a new ChannelOp
|
// 1. Create a new ChannelOp
|
||||||
rewriter.setInsertionPoint(computeOp);
|
rewriter.setInsertionPoint(computeOp);
|
||||||
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
auto channelType = spatial::SpatChannelType::get(computeOp.getContext());
|
||||||
auto channelOp = rewriter.create<spatial::SpatChannelNewOp>(loc, channelType);
|
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, channelType);
|
||||||
|
|
||||||
// 2. Receive value through the channel
|
// 2. Receive value through the channel
|
||||||
// If this result is used by more than one user, then use a "Broadcast"
|
// If this result is used by more than one user, then use a "Broadcast"
|
||||||
@@ -244,9 +245,9 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
// 3. Send the value through the channel
|
// 3. Send the value through the channel
|
||||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||||
if (useBroadcastOp)
|
if (useBroadcastOp)
|
||||||
rewriter.create<spatial::SpatChannelBroadcastSendOp>(loc, channelOp, yieldValue);
|
spatial::SpatChannelBroadcastSendOp::create(rewriter, loc, channelOp, yieldValue);
|
||||||
else
|
else
|
||||||
rewriter.create<spatial::SpatChannelSendOp>(loc, channelOp, yieldValue);
|
spatial::SpatChannelSendOp::create(rewriter, loc, channelOp, yieldValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use `HaltOp` instead of `YieldOp`
|
// Use `HaltOp` instead of `YieldOp`
|
||||||
@@ -255,17 +256,17 @@ void SpatialToPIMPass::runOnComputeOp(spatial::SpatWeightedCompute computeOp, IR
|
|||||||
|
|
||||||
// Replace `spat.compute` with `pim.core`
|
// Replace `spat.compute` with `pim.core`
|
||||||
rewriter.setInsertionPointAfter(computeOp);
|
rewriter.setInsertionPointAfter(computeOp);
|
||||||
auto coreOp = rewriter.create<PimCoreOp>(loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
|
auto coreOp = PimCoreOp::create(rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(coreId++));
|
||||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||||
block.eraseArguments(0, block.getNumArguments());
|
block.eraseArguments(0, block.getNumArguments());
|
||||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||||
Block* tempComputeBlock = new Block();
|
Block* tempComputeBlock = new Block();
|
||||||
computeOp.getBody().push_back(tempComputeBlock);
|
computeOp.getBody().push_back(tempComputeBlock);
|
||||||
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
rewriter.setInsertionPointToEnd(tempComputeBlock);
|
||||||
rewriter.create<PimHaltOp>(computeOp.getLoc());
|
PimHaltOp::create(rewriter, computeOp.getLoc());
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
auto enlargeTiedDpsChain = [&](Value value, RankedTensorType newType, auto& self) -> void {
|
||||||
auto* definingOp = value.getDefiningOp();
|
auto* definingOp = value.getDefiningOp();
|
||||||
if (!definingOp)
|
if (!definingOp)
|
||||||
@@ -302,14 +303,14 @@ void SpatialToPIMPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
|
|||||||
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
SmallVector<OpFoldResult> sizes = {oldShapeZeroAttr, oldShapeOneAttr};
|
||||||
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
SmallVector<OpFoldResult> strides = {oneAttr, oneAttr};
|
||||||
rewriter.setInsertionPointAfter(vmmOp);
|
rewriter.setInsertionPointAfter(vmmOp);
|
||||||
auto sliceOp = rewriter.create<tensor::ExtractSliceOp>(vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
auto sliceOp = tensor::ExtractSliceOp::create(rewriter, vmmOp.getLoc(), resultTensor, offsets, sizes, strides);
|
||||||
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
SmallPtrSet<Operation*, 2> exceptions = {vmmOp, sliceOp};
|
||||||
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
resultTensor.replaceAllUsesExcept(sliceOp.getResult(), exceptions);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
outputTensors.reserve(returnOp->getNumOperands());
|
outputTensors.reserve(returnOp->getNumOperands());
|
||||||
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
rewriter.setInsertionPointToStart(returnOp->getBlock());
|
||||||
for (auto returnValue : returnOp->getOperands()) {
|
for (auto returnValue : returnOp->getOperands()) {
|
||||||
@@ -326,7 +327,7 @@ void SpatialToPIMPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
|
||||||
Location loc = funcOp.getLoc();
|
Location loc = funcOp.getLoc();
|
||||||
|
|
||||||
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) {
|
||||||
@@ -335,9 +336,10 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
|
||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace));
|
||||||
|
|
||||||
auto deviceTensor = rewriter.create<tensor::EmptyOp>(loc, tensorType.getShape(), elementType);
|
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
|
||||||
|
|
||||||
auto memCopyHostToDevOp = rewriter.create<PimMemCopyHostToDevOp>(
|
auto memCopyHostToDevOp = PimMemCopyHostToDevOp::create(
|
||||||
|
rewriter,
|
||||||
loc,
|
loc,
|
||||||
tensorType,
|
tensorType,
|
||||||
deviceTensor,
|
deviceTensor,
|
||||||
@@ -362,7 +364,8 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
|
|
||||||
Block& block = funcOp.getBody().front();
|
Block& block = funcOp.getBody().front();
|
||||||
rewriter.setInsertionPoint(&block.front());
|
rewriter.setInsertionPoint(&block.front());
|
||||||
auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
auto toTensorOp =
|
||||||
|
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
|
||||||
inputTensors.push_back(toTensorOp);
|
inputTensors.push_back(toTensorOp);
|
||||||
|
|
||||||
tensorArg.replaceAllUsesWith(toTensorOp);
|
tensorArg.replaceAllUsesWith(toTensorOp);
|
||||||
@@ -415,7 +418,7 @@ void SpatialToPIMPass::allocateAndInitializeCoreLocalVariables(func::FuncOp func
|
|||||||
rewriter.eraseOp(sliceOp);
|
rewriter.eraseOp(sliceOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
void SpatialToPimPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompute& computeOp,
|
||||||
unsigned int argIndex,
|
unsigned int argIndex,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& tensorType,
|
Type& tensorType,
|
||||||
@@ -431,14 +434,14 @@ void SpatialToPIMPass::replaceBlockArgumentWithRecvOp(spatial::SpatWeightedCompu
|
|||||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||||
Value receivedValue;
|
Value receivedValue;
|
||||||
if (useBroadcastOp)
|
if (useBroadcastOp)
|
||||||
receivedValue = rewriter.create<spatial::SpatChannelBroadcastReceiveOp>(computeOp.getLoc(), tensorType, channel);
|
receivedValue = spatial::SpatChannelBroadcastReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
||||||
else
|
else
|
||||||
receivedValue = rewriter.create<spatial::SpatChannelReceiveOp>(computeOp.getLoc(), tensorType, channel);
|
receivedValue = spatial::SpatChannelReceiveOp::create(rewriter, computeOp.getLoc(), tensorType, channel);
|
||||||
|
|
||||||
blockArg.replaceAllUsesWith(receivedValue);
|
blockArg.replaceAllUsesWith(receivedValue);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
|
void SpatialToPimPass::addReceiveOps(Value& channelSourceOp,
|
||||||
spatial::SpatChannelNewOp& channel,
|
spatial::SpatChannelNewOp& channel,
|
||||||
Type& channelTensorType,
|
Type& channelTensorType,
|
||||||
bool& useBroadcastOp,
|
bool& useBroadcastOp,
|
||||||
@@ -495,7 +498,7 @@ void SpatialToPIMPass::addReceiveOps(Value& channelSourceOp,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
|
||||||
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
for (auto it : llvm::enumerate(returnOp.getOperands())) {
|
||||||
Operation* returnOperand = it.value().getDefiningOp();
|
Operation* returnOperand = it.value().getDefiningOp();
|
||||||
|
|
||||||
@@ -514,7 +517,7 @@ void SpatialToPIMPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
void SpatialToPimPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, IRRewriter& rewriter) {
|
||||||
|
|
||||||
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
|
auto channel = cast<spatial::SpatChannelNewOp>(receiveOp.getChannel().getDefiningOp());
|
||||||
|
|
||||||
@@ -543,6 +546,6 @@ void SpatialToPIMPass::runOnReceiveOp(spatial::SpatChannelReceiveOp receiveOp, I
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Pass> createSpatialToPIMPass() { return std::make_unique<SpatialToPIMPass>(); }
|
std::unique_ptr<Pass> createSpatialToPimPass() { return std::make_unique<SpatialToPimPass>(); }
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
@@ -13,7 +13,7 @@ add_onnx_mlir_library(OMPimBufferization
|
|||||||
PimBufferizationIncGen
|
PimBufferizationIncGen
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
LINK_LIBS PUBLIC
|
||||||
OMPIMCommon
|
OMPimCommon
|
||||||
PimOps
|
PimOps
|
||||||
|
|
||||||
ACCEL_INCLUDE_DIRS PRIVATE
|
ACCEL_INCLUDE_DIRS PRIVATE
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
#include "Dialect/Pim/PimOps.hpp"
|
#include "Dialect/Pim/PimOps.hpp"
|
||||||
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/Transforms/SpatialBufferizableOpInterface.hpp"
|
||||||
@@ -34,7 +34,7 @@ memref::AllocOp createEmptyFromType(Type resultType, Location loc, RewriterBase&
|
|||||||
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
|
auto memrefResultType = MemRefType::get(resultShape.getShape(), resultShape.getElementType());
|
||||||
|
|
||||||
// Alloc an output memref
|
// Alloc an output memref
|
||||||
return rewriter.create<memref::AllocOp>(loc, memrefResultType);
|
return memref::AllocOp::create(rewriter, loc, memrefResultType);
|
||||||
}
|
}
|
||||||
|
|
||||||
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
const llvm::StringRef PRECOMPUTED_OTHER_CORE_ID_ATTR_NAME("precomp_other_core_id");
|
||||||
@@ -134,7 +134,7 @@ struct VariadicArgumentElementWiseOpInterface : BufferizableOpInterface::Externa
|
|||||||
|
|
||||||
memrefOperands.push_back(outputTensor);
|
memrefOperands.push_back(outputTensor);
|
||||||
|
|
||||||
Value newValue = rewriter.create<ToTy>(op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
Value newValue = ToTy::create(rewriter, op->getLoc(), outputTensor.getType(), memrefOperands).getOutRes();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
|
|
||||||
@@ -169,10 +169,12 @@ struct WeightedMultiplicationsOpInterface : BufferizableOpInterface::ExternalMod
|
|||||||
// Alloc an output memref
|
// Alloc an output memref
|
||||||
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
Value outputTensor = createEmptyFromType(op->getResult(0).getType(), op->getLoc(), rewriter);
|
||||||
|
|
||||||
Value newValue =
|
Value newValue = ToTy::create(rewriter,
|
||||||
rewriter
|
op->getLoc(),
|
||||||
.create<ToTy>(
|
outputTensor.getType(),
|
||||||
op->getLoc(), outputTensor.getType(), cast<OpTy>(op).getWeightIndexAttr(), memrefOperand, outputTensor)
|
cast<OpTy>(op).getWeightIndexAttr(),
|
||||||
|
memrefOperand,
|
||||||
|
outputTensor)
|
||||||
.getOutRes();
|
.getOutRes();
|
||||||
|
|
||||||
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
replaceOpWithBufferizedValues(rewriter, op, newValue);
|
||||||
@@ -213,8 +215,8 @@ struct ChannelReceiveOpInterface
|
|||||||
if (failed(srcCoreId))
|
if (failed(srcCoreId))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
Value newValue = rewriter
|
Value newValue = pim::PimReceiveOp::create(rewriter,
|
||||||
.create<pim::PimReceiveOp>(op->getLoc(),
|
op->getLoc(),
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
rewriter.getI32IntegerAttr(numElements * elementSize),
|
rewriter.getI32IntegerAttr(numElements * elementSize),
|
||||||
@@ -324,7 +326,8 @@ struct ChannelBroadcastReceiveOpInterface
|
|||||||
}
|
}
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
auto memCopyHostToDevOp = rewriter.create<pim::PimMemCopyHostToDevOp>(op->getLoc(),
|
auto memCopyHostToDevOp = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||||
|
op->getLoc(),
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
outputTensor,
|
outputTensor,
|
||||||
bufferAllocation,
|
bufferAllocation,
|
||||||
@@ -395,7 +398,8 @@ struct ChannelBroadcastSendOpInterface
|
|||||||
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
auto sizeInBytes = srcType.getNumElements() * srcType.getElementTypeBitWidth() / 8;
|
||||||
|
|
||||||
rewriter.setInsertionPoint(op);
|
rewriter.setInsertionPoint(op);
|
||||||
rewriter.create<pim::PimMemCopyDevToHostOp>(op->getLoc(),
|
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||||
|
op->getLoc(),
|
||||||
bufferAllocation.getType(),
|
bufferAllocation.getType(),
|
||||||
bufferAllocation,
|
bufferAllocation,
|
||||||
srcMemRef,
|
srcMemRef,
|
||||||
@@ -481,7 +485,8 @@ struct ApplyFiltersOpInterface : BufferizableOpInterface::ExternalModel<ApplyFil
|
|||||||
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
auto xKernelPositions = cast<SpatApplyFiltersOp>(op).getXKernelPositionsAttr();
|
||||||
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
auto yKernelPositions = cast<SpatApplyFiltersOp>(op).getYKernelPositionsAttr();
|
||||||
|
|
||||||
Value bufferized = rewriter.create<pim::PimApplyFiltersOp>(op->getLoc(),
|
Value bufferized = pim::PimApplyFiltersOp::create(rewriter,
|
||||||
|
op->getLoc(),
|
||||||
outputTensor.getType(),
|
outputTensor.getType(),
|
||||||
weightIndices,
|
weightIndices,
|
||||||
xKernelPositions,
|
xKernelPositions,
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
|
|
||||||
#include "Common/PIMCommon.hpp"
|
#include "Common/PimCommon.hpp"
|
||||||
#include "Compiler/PimCodeGen.hpp"
|
#include "Compiler/PimCodeGen.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
||||||
#include "src/Accelerators/PIM/Common/PIMCommon.hpp"
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ std::unique_ptr<mlir::Pass> createONNXToSpatialPass();
|
|||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
std::unique_ptr<mlir::Pass> createSpatialToGraphvizPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createSpatialToPIMPass();
|
std::unique_ptr<mlir::Pass> createSpatialToPimPass();
|
||||||
|
|
||||||
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
std::unique_ptr<mlir::Pass> createBufferizePimPass();
|
||||||
|
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ void PimAccelerator::registerPasses(int optLevel) const {
|
|||||||
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
LLVM_DEBUG(llvm::dbgs() << "Registering passes for PIM accelerator\n");
|
||||||
registerPass(createONNXToSpatialPass);
|
registerPass(createONNXToSpatialPass);
|
||||||
registerPass(createSpatialToGraphvizPass);
|
registerPass(createSpatialToGraphvizPass);
|
||||||
registerPass(createSpatialToPIMPass);
|
registerPass(createSpatialToPimPass);
|
||||||
registerPass(createBufferizePimPass);
|
registerPass(createBufferizePimPass);
|
||||||
registerPass(createPimFoldHostConstantsPass);
|
registerPass(createPimFoldHostConstantsPass);
|
||||||
registerPass(createPimHostVerificationPass);
|
registerPass(createPimHostVerificationPass);
|
||||||
|
|||||||
Reference in New Issue
Block a user