Compare commits

...

3 Commits

Author SHA1 Message Date
ilgeco
08b0fcd850 Parallel bufferization
All checks were successful
Validate Operations / validate-operations (push) Successful in 21m49s
2026-04-30 11:48:17 +02:00
ilgeco
9dccc2c701 Translate global constant to symble 2026-04-28 12:42:01 +02:00
ilgeco
5c839e62c1 Func Input converted to symbol 2026-04-27 13:48:03 +02:00
12 changed files with 584 additions and 135 deletions

View File

@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"

View File

@@ -3,9 +3,11 @@
#include "mlir/IR/Attributes.h" #include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
@@ -53,9 +55,23 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) { void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants; SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases; SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
SmallVector<mlir::Value> args;
for (mlir::Value arg : funcOp.getArguments()){
gatherMemEntry(arg);
args.push_back(arg);
}
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) { funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (!hasWeightAlways(getGlobalOp)) { if (!hasWeightAlways(getGlobalOp)) {
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (globalMemrefOp.getName().starts_with("arg")){
StringRef indexStr = globalMemrefOp.getName().substr(4);
int index = 0;
llvm::to_integer(indexStr,index, 10);
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
}
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult()); auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
if (inserted) if (inserted)
gatherMemEntry(getGlobalOp.getResult()); gatherMemEntry(getGlobalOp.getResult());
@@ -64,8 +80,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
} }
}); });
for (mlir::Value arg : funcOp.getArguments())
gatherMemEntry(arg);
funcOp.walk([&](memref::AllocOp allocOp) { funcOp.walk([&](memref::AllocOp allocOp) {
if (!allocOp->getParentOfType<pim::PimCoreOp>()) if (!allocOp->getParentOfType<pim::PimCoreOp>())
@@ -412,6 +426,9 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa
emitInstruction(std::move(json)); emitInstruction(std::move(json));
} }
void PimCodeGen::codeGetGlobalOp(memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const {
}
void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const { void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const {
auto srcAddr = addressOf(transposeOp.getInput(), knowledge); auto srcAddr = addressOf(transposeOp.getInput(), knowledge);
auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge); auto dstAddr = addressOf(transposeOp.getOutputBuffer(), knowledge);
@@ -581,6 +598,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) {
coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge); coreCodeGen.codeGenVSigmOp(vsigmOp, knowledge);
else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op)) else if (auto vsoftmaxOp = dyn_cast<pim::PimVSoftmaxOp>(op))
coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge); coreCodeGen.codeGenVSoftmaxOp(vsoftmaxOp, knowledge);
else if (auto getGlobalOp = dyn_cast<memref::GetGlobalOp>(op))
coreCodeGen.codeGetGlobalOp(getGlobalOp, knowledge);
else { else {
op.emitError("Unsupported codegen for this operation"); op.emitError("Unsupported codegen for this operation");
op.dump(); op.dump();

View File

@@ -106,6 +106,7 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const; void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const; void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const; void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const; void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
}; };

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -11,6 +12,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h" #include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <fstream> #include <fstream>
@@ -144,6 +146,7 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n"; llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc); encapsulateGlobalInstruction(*entryFunc);
if (failed(promoteConstantInputsToWeights(*entryFunc))) { if (failed(promoteConstantInputsToWeights(*entryFunc))) {
@@ -160,19 +163,36 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) { if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp); Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) { auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source); auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc}); newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1}); rewriter.setInsertionPointToEnd(BB);
rewriter.setInsertionPointToEnd(BB); IRMapping mapper;
IRMapping mapper; mapper.map(source, BB->getArgument(0));
mapper.map(source, BB->getArgument(0)); auto newInst = rewriter.clone(*inst, mapper);
auto newInst = rewriter.clone(*inst, mapper); spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults()); inst->replaceAllUsesWith(newCompute->getResults());
inst->replaceAllUsesWith(newCompute->getResults()); inst->erase();
inst->erase(); return true;
return true; }
} return false;
}
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
auto source = toRemoveOp.getSource();
rewriter.setInsertionPointAfter(toRemoveOp);
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
} }
return false; return false;
} }
@@ -181,27 +201,24 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) { if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs(); auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp); rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(sources, auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) { SmallVector<Type> sourceTypes;
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources); SmallVector<Location> sourceLoc;
SmallVector<Type> sourceTypes; for (auto source : sources) {
SmallVector<Location> sourceLoc; sourceTypes.push_back(source.getType());
for (auto source : sources) { sourceLoc.push_back(loc);
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
} }
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResults());
inst->replaceAllUsesWith(newCompute->getResults());
inst->erase();
return true;
} }
return false; return false;
} }
@@ -263,6 +280,72 @@ static FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewrite
return cast<Value>(mapped); return cast<Value>(mapped);
} }
bool sourceOpernadHasWeightAlways(Operation* op) {
if (op == nullptr)
return false;
Operation* source = nullptr;
do {
if (isa<spatial::SpatCompute>(*op)) {
return false;
}
else if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(*op)) {
auto tmpSource = extractSliceOp.getSource();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(*op)) {
auto tmpSource = expandShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto transposeOp = dyn_cast<ONNXTransposeOp>(*op)) {
auto tmpSource = transposeOp.getData();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(*op)) {
auto tmpSource = collapseShapeOp.getSrc();
auto definingOp = tmpSource.getDefiningOp();
if (definingOp)
op = definingOp;
else
return false;
}
else if (auto constantOp = dyn_cast<arith::ConstantOp>(*op)) {
source = constantOp;
}
else if (auto concatOp = dyn_cast<tensor::ConcatOp>(*op)) {
bool res = false;
for (auto operand : concatOp.getOperands()) {
res |= hasWeightAlways(operand.getDefiningOp());
if (res)
return res;
}
return res;
}
else {
op->dump();
llvm_unreachable("Global instruction not handle in func");
}
}
while (source == nullptr);
if (hasWeightAlways(source))
return true;
return false;
}
// TODO what we want to keep in global? // TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) { void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
@@ -271,8 +354,12 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
while (keep) { while (keep) {
keep = false; keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) { for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); }); if (isa<spatial::SpatCompute>(instruction) || isa<func::ReturnOp>(instruction)
|| sourceOpernadHasWeightAlways(&instruction))
continue;
keep |= encapsulateSlice(rewriter, loc, &instruction);
keep |= encapsulator<tensor::ExpandShapeOp>( keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); }); rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });

View File

@@ -5,6 +5,7 @@ add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim add_pim_library(OMSpatialToPim
SpatialToPimPass.cpp SpatialToPimPass.cpp
Common.cpp Common.cpp
Patterns.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS

View File

@@ -0,0 +1,287 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::ExtractSliceOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override {
Location loc = extractSliceOp.getLoc();
if (!isa<func::FuncOp>(extractSliceOp->getParentOp()))
return failure();
for (auto& uses : extractSliceOp->getUses()) {
if (isa<spatial::SpatCompute>(uses.getOwner())) {
auto spatCompute = cast<spatial::SpatCompute>(uses.getOwner());
if (spatCompute.getInputs().empty())
return failure();
if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex())
return failure();
}
else if (isa_and_present<func::FuncOp>(uses.getOwner()->getParentOp())) {
return failure();
}
}
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatToExtract;
for (auto& uses : llvm::make_early_inc_range(extractSliceOp->getUses())) {
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(uses.getOwner())) {
auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
if (BBArgValue.use_empty())
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute)) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute, newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatToExtract[spatCompute]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
{
auto spatCompute = uses.getOwner()->getParentOfType<spatial::SpatCompute>();
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatToExtract.contains(spatCompute)) {
auto newExtractSlice = rewriter.clone(*extractSliceOp.getOperation());
mapSpatToExtract.insert({spatCompute, newExtractSlice->getResult(0)});
}
rewriter.startOpModification(spatCompute.getOperation());
uses.set(mapSpatToExtract[spatCompute]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
}
}
rewriter.eraseOp(extractSliceOp);
return success();
}
};
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
static int i = 0;
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
std::string argName = "const_" + std::to_string(i++);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
constantOp.getValueAttr(),
rewriter.getUnitAttr(),
{});
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute)) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute, toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(mapSpatComputeToConst[spatCompute]);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
{
auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>();
if (!spatCompute)
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute)) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute, toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
auto parent = constUsers->getParentOfType<spatial::SpatCompute>();
assert(parent && "Global Constant used direcly not within a compute");
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent, newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent]);
}
}
}
auto parent = constantOp->getParentOp();
rewriter.eraseOp(constantOp);
return success();
}
};
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::func::FuncOp funcOp, PatternRewriter& rewriter) const override {
if (funcOp.getArguments().empty())
return failure();
if (llvm::all_of(funcOp.getArguments(),
[](mlir::BlockArgument blockArgument) { return blockArgument.use_empty(); }))
return failure();
Location loc = funcOp.getLoc();
for (auto [index, arg] : llvm::enumerate(funcOp.getArguments())) {
if (arg.getUses().empty())
continue;
rewriter.setInsertionPoint(funcOp.getOperation());
assert(isa<mlir::RankedTensorType>(arg.getType()));
auto argRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(arg.getType());
mlir::MemRefType memRefType =
mlir::MemRefType::get(argRankedTensorType.getShape(), argRankedTensorType.getElementType());
std::string argName = "arg_" + std::to_string(index);
memref::GlobalOp::create(rewriter,
loc,
rewriter.getStringAttr(argName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
auto argUser = argUses.getOwner();
if (auto spatCompute = dyn_cast<spatial::SpatCompute>(argUser)) {
auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
rewriter.startOpModification(spatCompute.getOperation());
BBArgValue.replaceAllUsesWith(toTensor);
spatCompute.getInputsMutable().erase(BBArgIndex);
spatCompute.getBody().front().eraseArgument(BBArgIndex);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else {
rewriter.setInsertionPoint(argUser);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
rewriter.startOpModification(argUser);
argUses.set(getGlobalOp);
rewriter.finalizeOpModification(argUser);
}
}
}
return success();
}
};
} // namespace
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
}
} // namespace onnx_mlir

View File

@@ -0,0 +1,10 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorToMemrefPatterns(mlir::RewritePatternSet& patterns);
}

View File

@@ -1,20 +1,26 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h" #include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <cassert> #include <cassert>
@@ -23,6 +29,7 @@
#include <utility> #include <utility>
#include "Conversion/ONNXToSpatial/Common.hpp" #include "Conversion/ONNXToSpatial/Common.hpp"
#include "Patterns.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -51,7 +58,7 @@ struct SpatialToPimPass : PassWrapper<SpatialToPimPass, OperationPass<ModuleOp>>
void runOnOperation() final; void runOnOperation() final;
private: private:
SmallVector<Value> outputTensors; SmallVector<std::function<Value(IRRewriter& rewriter, Location loc)>> outputTensors;
size_t coreId = 0; size_t coreId = 0;
SmallVector<Operation*> operationsToRemove; SmallVector<Operation*> operationsToRemove;
@@ -146,12 +153,21 @@ void SpatialToPimPass::runOnOperation() {
scf::SCFDialect, scf::SCFDialect,
BuiltinDialect>(); BuiltinDialect>();
RewritePatternSet patterns(ctx); {
populateWithGenerated(patterns); RewritePatternSet patterns(ctx);
populateWithGenerated(patterns);
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
signalPassFailure(); signalPassFailure();
return; return;
}
}
{
RewritePatternSet patterns(ctx);
populateGlobalTensorToMemrefPatterns(patterns);
walkAndApplyPatterns(moduleOp, std::move(patterns));
} }
auto entryFunc = getPimEntryFunc(moduleOp); auto entryFunc = getPimEntryFunc(moduleOp);
@@ -278,7 +294,7 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
auto storedType = cast<ShapedType>(storedValue.getType()); auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8; size_t elementSize = storedType.getElementTypeBitWidth() / 8;
Value outputTensor = outputTensors[resultIndexInReturn]; auto outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
if (auto storedOp = storedValue.getDefiningOp()) if (auto storedOp = storedValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp); rewriter.setInsertionPointAfter(storedOp);
PimMemCopyDevToHostOp::create(rewriter, PimMemCopyDevToHostOp::create(rewriter,
@@ -300,8 +316,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8; size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
// Store to global memory // Store to global memory
Value outputTensor = outputTensors[resultIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
PimMemCopyDevToHostOp::create(rewriter, PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
@@ -341,8 +357,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter
size_t elementSize = yieldType.getElementTypeBitWidth() / 8; size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
// Store to global memory // Store to global memory
Value outputTensor = outputTensors[concatIndexInReturn];
rewriter.setInsertionPointAfterValue(yieldValue); rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = outputTensors[concatIndexInReturn](rewriter, loc);
PimMemCopyDevToHostOp::create(rewriter, PimMemCopyDevToHostOp::create(rewriter,
loc, loc,
outputTensor.getType(), outputTensor.getType(),
@@ -448,17 +464,35 @@ void SpatialToPimPass::enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, I
void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) { void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands()); outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
rewriter.setInsertionPointToStart(returnOp->getBlock()); rewriter.setInsertionPointToStart(returnOp->getBlock());
for (auto returnValue : returnOp->getOperands()) {
Operation* returnValueDefiningOp = returnValue.getDefiningOp(); Operation* returnValueDefiningOp = returnValue.getDefiningOp();
if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) { if (returnValueDefiningOp->hasTrait<OpTrait::ConstantLike>()) {
assert(!hasWeightAlways(returnValueDefiningOp)); assert(!hasWeightAlways(returnValueDefiningOp));
outputTensors.push_back(returnValue); outputTensors.push_back( [returnValue] (IRRewriter& rewriter, Location loc) -> Value { return returnValue; });
} }
else { else {
auto newOutputTensor = auto outRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(returnValue.getType());
createEmptyTensorFromShaped(rewriter, returnValue.getLoc(), cast<ShapedType>(returnValue.getType())); mlir::MemRefType memRefType =
outputTensors.push_back(newOutputTensor); mlir::MemRefType::get(outRankedTensorType.getShape(), outRankedTensorType.getElementType());
std::string outputName = "output_" + std::to_string(index);
rewriter.setInsertionPoint(returnOp.getParentOp());
memref::GlobalOp::create(rewriter,
returnOp.getLoc(),
rewriter.getStringAttr(outputName),
rewriter.getStringAttr("private"),
TypeAttr::get(memRefType),
{},
{},
{});
outputTensors.push_back(
[memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, outRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
return toTensor.getResult();
});
} }
} }
} }
@@ -466,11 +500,11 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew
LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) { LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter) {
Location loc = funcOp.getLoc(); Location loc = funcOp.getLoc();
auto insertMemCopyHostToDev = [&](auto valueToReplace, auto hostTensor, int64_t elementsOffset) { auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) {
auto tensorType = cast<ShapedType>(valueToReplace.getType()); auto tensorType = cast<ShapedType>(inputTensor.getType());
Type elementType = tensorType.getElementType(); Type elementType = tensorType.getElementType();
size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(valueToReplace)); rewriter.setInsertionPointAfter(inputTensor.getDefiningOp());
auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType);
@@ -479,85 +513,28 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
loc, loc,
tensorType, tensorType,
deviceTensor, deviceTensor,
hostTensor, inputTensor,
rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)), rewriter.getI32IntegerAttr(static_cast<int32_t>(elementsOffset * elementByteSize)),
rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize))); rewriter.getI32IntegerAttr(static_cast<int32_t>(tensorType.getNumElements() * elementByteSize)));
rewriter.replaceAllUsesWith(valueToReplace, memCopyHostToDevOp.getResult()); rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
}; };
// Replace input tensors with memRefs
SmallVector<bufferization::ToTensorOp, 8> inputTensors;
for (size_t i = 0; i < funcOp.getNumArguments(); i++) {
BlockArgument tensorArg = funcOp.getArgument(i);
DictionaryAttr tensorArgAttrs = funcOp.getArgAttrDict(i);
ShapedType tensorArgType = cast<ShapedType>(tensorArg.getType());
MemRefType memRefArgType = MemRefType::get(tensorArgType.getShape(), tensorArgType.getElementType());
if (failed(funcOp.insertArgument(i + 1, memRefArgType, tensorArgAttrs, loc)))
return funcOp.emitError("failed to insert memref argument during Spatial-to-Pim lowering");
BlockArgument memRefArg = funcOp.getArgument(i + 1);
Block& block = funcOp.getBody().front();
rewriter.setInsertionPoint(&block.front());
auto toTensorOp =
bufferization::ToTensorOp::create(rewriter, loc, tensorArgType, memRefArg, rewriter.getUnitAttr());
inputTensors.push_back(toTensorOp);
tensorArg.replaceAllUsesWith(toTensorOp);
if (failed(funcOp.eraseArgument(i)))
return funcOp.emitError("failed to erase tensor argument during Spatial-to-Pim lowering");
}
llvm::SmallSet<tensor::ExtractSliceOp, 8> sliceOpsToRemove;
for (auto& op : funcOp.getBody().getOps()) for (auto& op : funcOp.getBody().getOps())
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) { if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
unsigned numComputeWeights = computeOp.getWeights().size(); assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle");
for (auto [computeInputIdx, computeOpInput] : llvm::enumerate(computeOp.getInputs())) { assert(computeOp.getBody().front().getNumArguments() == 0
TypedValue<TensorType> tensorSource; && "Already removed from mergeNode and global input handle");
int64_t elementsOffset = 0; for (auto getGlobal : computeOp.getOps<memref::GetGlobalOp>()) {
if (getGlobal.getName().starts_with("arg")) {
if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(computeOpInput.getDefiningOp())) { assert(getGlobal->hasOneUse() && "global must have a single entry point in the compute");
tensorSource = cast<TypedValue<TensorType>>(sliceOp.getSource()); auto toTensorOpValue = *getGlobal->getUsers().begin()->getResults().begin();
insertMemCopyHostToDev(toTensorOpValue, 0);
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
ArrayRef<int64_t> sourceShape = tensorSource.getType().getShape();
ArrayRef<int64_t> sliceOffsets = sliceOp.getStaticOffsets();
ArrayRef<int64_t> sliceSizes = sliceOp.getStaticSizes();
ArrayRef<int64_t> sliceStrides = sliceOp.getStaticStrides();
assert("Extracting slice non-contiguous in memory"
&& isMemoryContiguous(sourceShape, sliceOffsets, sliceSizes, sliceStrides));
for (size_t i = 0; i < sliceOffsets.size(); i++) {
int64_t partialOffset = sliceOffsets[i];
if (partialOffset != 0)
for (size_t j = i + 1; j < sourceShape.size(); j++)
partialOffset *= sourceShape[j];
elementsOffset += partialOffset;
}
computeOp.setOperand(numComputeWeights + computeInputIdx, tensorSource);
sliceOpsToRemove.insert(sliceOp);
} }
else
tensorSource = cast<TypedValue<TensorType>>(computeOpInput);
// Compute results must be transferred through channels via send/receive
if (isa<spatial::SpatCompute>(tensorSource.getDefiningOp()))
continue;
BlockArgument computeBlockArgToReplace = computeOp.getBody().front().getArgument(computeInputIdx);
insertMemCopyHostToDev(computeBlockArgToReplace, tensorSource, elementsOffset);
} }
} }
for (auto sliceOp : sliceOpsToRemove)
if (sliceOp->getUses().empty())
rewriter.eraseOp(sliceOp);
return success(); return success();
} }
@@ -735,12 +712,13 @@ void SpatialToPimPass::lowerBroadcastChannelOps(func::FuncOp funcOp, IRRewriter&
void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) { void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewriter& rewriter) {
SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end()); SmallVector<Value> originalOperands(returnOp.getOperands().begin(), returnOp.getOperands().end());
auto loc = returnOp.getLoc();
for (auto it : llvm::enumerate(originalOperands)) { for (auto it : llvm::enumerate(originalOperands)) {
size_t orderWithinReturn = it.index(); size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp(); Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
rewriter.modifyOpInPlace(returnOp, rewriter.modifyOpInPlace(returnOp,
[&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn]); }); [&] { returnOp.setOperand(orderWithinReturn, outputTensors[orderWithinReturn](rewriter, loc)); });
Operation* opToErase = returnOperand; Operation* opToErase = returnOperand;
while (opToErase) { while (opToErase) {

View File

@@ -24,7 +24,7 @@ def PimTensor :
// Execution // Execution
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def PimCoreOp : PimOp<"core", [SingleBlock]> { def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> {
let summary = "Execute a block on a PIM core"; let summary = "Execute a block on a PIM core";
let regions = (region SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$body);

View File

@@ -3,12 +3,17 @@
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Threading.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "Compiler/PimCodeGen.hpp" #include "Compiler/PimCodeGen.hpp"
#include "Dialect/Pim/PimOps.hpp" #include "Dialect/Pim/PimOps.hpp"
#include "Dialect/Pim/Transforms/Bufferization/Common.hpp" #include "Dialect/Pim/Transforms/Bufferization/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Pass/PIMPasses.h" #include "src/Accelerators/PIM/Pass/PIMPasses.h"
#include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerOptions.hpp"
@@ -40,14 +45,44 @@ private:
void PimBufferizationPass::runOnOperation() { void PimBufferizationPass::runOnOperation() {
auto moduleOp = getOperation(); auto moduleOp = getOperation();
// Refactor this into a function
{
auto funcOp = getPimEntryFunc(moduleOp);
// One-Shot-Bufferization auto coreOps = llvm::to_vector(funcOp->getOps<pim::PimCoreOp>());
bufferization::OneShotBufferizationOptions options; MLIRContext* ctx = moduleOp.getContext();
options.allowUnknownOps = true; // failableParallelForEach will run the lambda in parallel and stop if any thread fails
bufferization::BufferizationState state; LogicalResult result = mlir::failableParallelForEach(ctx, coreOps, [&](pim::PimCoreOp coreOp) {
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) { // Again, allocate state LOCALLY per thread/function
moduleOp.emitError("Failed to bufferize PIM and Spatial ops"); bufferization::OneShotBufferizationOptions options;
signalPassFailure(); options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(coreOp, options, state))) {
coreOp.emitError("Failed to bufferize PIM and Spatial ops");
return failure();
}
return success();
});
if (failed(result)) {
moduleOp.emitError("Failed to bufferize-parallel PIM and Spatial ops");
signalPassFailure();
}
funcOp->walk([&](bufferization::ToTensorOp toTensorOp) {
if (llvm::isa_and_present<pim::PimCoreOp>(toTensorOp->getParentOp()))
toTensorOp->setAttr("restrict", UnitAttr::get(ctx));
});
// One-Shot-Bufferization
bufferization::OneShotBufferizationOptions options;
options.allowUnknownOps = true;
bufferization::BufferizationState state;
if (failed(bufferization::runOneShotBufferize(moduleOp, options, state))) {
moduleOp.emitError("Failed to bufferize PIM and Spatial ops");
signalPassFailure();
}
} }
MLIRContext* ctx = moduleOp.getContext(); MLIRContext* ctx = moduleOp.getContext();

View File

@@ -735,6 +735,26 @@ public:
LogicalResult initialize(MLIRContext* context) override { return success(); } LogicalResult initialize(MLIRContext* context) override { return success(); }
void verifyOrderAssumption(std::vector<spatial::SpatCompute>& dominanceOrderCompute) {
uint64_t computeNumber = 0;
llvm::DenseSet<SpatCompute> visited;
mlir::func::FuncOp funcOp = getOperation();
for (auto spatCompute : funcOp.getOps<SpatCompute>())
computeNumber++;
assert(computeNumber == dominanceOrderCompute.size());
for(auto domCompute : dominanceOrderCompute){
visited.insert(domCompute);
for(auto domInput : domCompute.getInputs() ){
if(auto domImputAsCompute = dyn_cast_if_present<SpatCompute>(domInput.getDefiningOp())){
assert(visited.contains(domImputAsCompute) && "Dominance order violated\n");
}
}
}
}
void runOnOperation() override { void runOnOperation() override {
mergeTriviallyConnectedComputes(getOperation()); mergeTriviallyConnectedComputes(getOperation());
packWideWeightedVmmBands(getOperation()); packWideWeightedVmmBands(getOperation());
@@ -744,6 +764,9 @@ public:
auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu; auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap; auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
func::FuncOp func = getOperation();
verifyOrderAssumption(analysisResult.dominanceOrderCompute);
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) { for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode); size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
if (!cpuToNewComputeMap.contains(cpu)) { if (!cpuToNewComputeMap.contains(cpu)) {
@@ -765,11 +788,19 @@ public:
} }
for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) { for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
for (auto users : computeNodeToRemove->getUsers()) if (!computeNodeToRemove->use_empty()) {
llvm::dbgs() << "Full module\n";
computeNodeToRemove->getParentOfType<ModuleOp>()->dump();
llvm::dbgs() << "Compute with uses:\n";
computeNodeToRemove.dump();
}
for (auto users : computeNodeToRemove->getUsers()) {
llvm::dbgs() << "Users:\n";
users->dump(); users->dump();
}
computeNodeToRemove.erase(); computeNodeToRemove.erase();
} }
func::FuncOp func = getOperation();
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged"); dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
generateReport(func, "spatial1_dcp_merged_report"); generateReport(func, "spatial1_dcp_merged_report");
} }

View File

@@ -116,10 +116,9 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern<linalg::MapOp> {
auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill"); auto globalOp = createFoldedGlobal(moduleOp, mapOp.getLoc(), initType, splatAttr, "pim_core_fill");
OpBuilder::InsertionGuard guard(rewriter); OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(coreOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
rewriter.setInsertionPoint(mapOp); rewriter.setInsertionPoint(mapOp);
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName());
auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8;
pim::PimMemCopyOp::create(rewriter, pim::PimMemCopyOp::create(rewriter,
mapOp.getLoc(), mapOp.getLoc(),