Translate global constant to symble
This commit is contained in:
@@ -3,9 +3,11 @@
|
|||||||
#include "mlir/IR/Attributes.h"
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/BuiltinAttributes.h"
|
#include "mlir/IR/BuiltinAttributes.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/SmallPtrSet.h"
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
|
#include "llvm/ADT/StringExtras.h"
|
||||||
#include "llvm/Support/FileSystem.h"
|
#include "llvm/Support/FileSystem.h"
|
||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_ostream.h"
|
#include "llvm/Support/raw_ostream.h"
|
||||||
@@ -53,9 +55,23 @@ void PimMemory::allocateMemoryForValue(mlir::Value value, MemEntry& memEntry) {
|
|||||||
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
||||||
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
SmallDenseMap<memref::GlobalOp, mlir::Value, 8> globalConstants;
|
||||||
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
SmallVector<std::pair<mlir::Value, mlir::Value>, 16> globalAliases;
|
||||||
|
SmallVector<mlir::Value> args;
|
||||||
|
|
||||||
|
|
||||||
|
for (mlir::Value arg : funcOp.getArguments()){
|
||||||
|
gatherMemEntry(arg);
|
||||||
|
args.push_back(arg);
|
||||||
|
}
|
||||||
|
|
||||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||||
if (!hasWeightAlways(getGlobalOp)) {
|
if (!hasWeightAlways(getGlobalOp)) {
|
||||||
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
auto globalMemrefOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||||
|
if (globalMemrefOp.getName().starts_with("arg")){
|
||||||
|
StringRef indexStr = globalMemrefOp.getName().substr(4);
|
||||||
|
int index = 0;
|
||||||
|
llvm::to_integer(indexStr,index, 10);
|
||||||
|
globalAliases.push_back({getGlobalOp.getResult(), args[index]});
|
||||||
|
}
|
||||||
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
auto [iter, inserted] = globalConstants.try_emplace(globalMemrefOp, getGlobalOp.getResult());
|
||||||
if (inserted)
|
if (inserted)
|
||||||
gatherMemEntry(getGlobalOp.getResult());
|
gatherMemEntry(getGlobalOp.getResult());
|
||||||
@@ -64,8 +80,6 @@ void PimMemory::allocateHost(ModuleOp moduleOp, func::FuncOp funcOp) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
for (mlir::Value arg : funcOp.getArguments())
|
|
||||||
gatherMemEntry(arg);
|
|
||||||
|
|
||||||
funcOp.walk([&](memref::AllocOp allocOp) {
|
funcOp.walk([&](memref::AllocOp allocOp) {
|
||||||
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
if (!allocOp->getParentOfType<pim::PimCoreOp>())
|
||||||
|
|||||||
@@ -148,6 +148,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
|||||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||||
|
|
||||||
annotateWeightsConstants(*entryFunc);
|
annotateWeightsConstants(*entryFunc);
|
||||||
|
|
||||||
encapsulateGlobalInstruction(*entryFunc);
|
encapsulateGlobalInstruction(*entryFunc);
|
||||||
|
|
||||||
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
if (failed(promoteConstantInputsToWeights(*entryFunc))) {
|
||||||
@@ -166,19 +167,46 @@ bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::func
|
|||||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||||
Value source = funcSource(toRemoveOp);
|
Value source = funcSource(toRemoveOp);
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
if (isa_and_present<spatial::SpatCompute>(source.getDefiningOp())) {
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
rewriter.setInsertionPointToEnd(BB);
|
IRMapping mapper;
|
||||||
IRMapping mapper;
|
mapper.map(source, BB->getArgument(0));
|
||||||
mapper.map(source, BB->getArgument(0));
|
auto newInst = rewriter.clone(*inst, mapper);
|
||||||
auto newInst = rewriter.clone(*inst, mapper);
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
inst->replaceAllUsesWith(newCompute->getResults());
|
inst->erase();
|
||||||
inst->erase();
|
return true;
|
||||||
return true;
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool encapsulateSlice(IRRewriter& rewriter, Location loc, Operation* inst) {
|
||||||
|
if (tensor::ExtractSliceOp toRemoveOp = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(inst)) {
|
||||||
|
for (auto& use : toRemoveOp->getUses()) {
|
||||||
|
auto users = use.getOwner();
|
||||||
|
if (auto spatCompUser = dyn_cast<spatial::SpatCompute>(users)) {
|
||||||
|
unsigned int poistionUses = use.getOperandNumber();
|
||||||
|
if (poistionUses < spatCompUser.getInputs().getBeginOperandIndex())
|
||||||
|
return false;
|
||||||
|
}else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
auto source = toRemoveOp.getSource();
|
||||||
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), source);
|
||||||
|
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||||
|
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||||
|
rewriter.setInsertionPointToEnd(BB);
|
||||||
|
IRMapping mapper;
|
||||||
|
mapper.map(source, BB->getArgument(0));
|
||||||
|
auto newInst = rewriter.clone(*inst, mapper);
|
||||||
|
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResults());
|
||||||
|
inst->replaceAllUsesWith(newCompute->getResults());
|
||||||
|
inst->erase();
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@@ -187,8 +215,8 @@ bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
|
|||||||
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
|
||||||
auto sources = toRemoveOp.getInputs();
|
auto sources = toRemoveOp.getInputs();
|
||||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||||
if (llvm::any_of(
|
if (llvm::any_of(sources,
|
||||||
sources, [](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
[](auto source) { return isa_and_present<spatial::SpatCompute>(source.getDefiningOp()); })) {
|
||||||
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, inst->getResultTypes(), sources);
|
||||||
SmallVector<Type> sourceTypes;
|
SmallVector<Type> sourceTypes;
|
||||||
SmallVector<Location> sourceLoc;
|
SmallVector<Location> sourceLoc;
|
||||||
@@ -277,8 +305,7 @@ void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
|||||||
while (keep) {
|
while (keep) {
|
||||||
keep = false;
|
keep = false;
|
||||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
keep |= encapsulateSlice(rewriter, loc, &instruction);
|
||||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
|
||||||
|
|
||||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||||
@@ -324,8 +351,7 @@ void ONNXToSpatialPass::mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
|||||||
|
|
||||||
rewriter.setInsertionPointAfter(compute.getOperation());
|
rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
|
||||||
auto newCompute =
|
auto newCompute = spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
||||||
spatial::SpatCompute::create(rewriter, loc, child.getResultTypes(), compute.getOperands());
|
|
||||||
newCompute.getProperties().setOperandSegmentSizes(
|
newCompute.getProperties().setOperandSegmentSizes(
|
||||||
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
{static_cast<int>(compute.getWeights().size()), static_cast<int>(compute.getInputs().size())});
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,19 @@
|
|||||||
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||||
|
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||||
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/IR/BuiltinTypes.h"
|
#include "mlir/IR/BuiltinTypes.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
#include "mlir/IR/PatternMatch.h"
|
||||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
#include "mlir/IR/Value.h"
|
||||||
|
|
||||||
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/Support/Casting.h"
|
||||||
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
|
#include "llvm/Support/LogicalResult.h"
|
||||||
|
|
||||||
|
#include "Common/PimCommon.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
@@ -13,6 +21,96 @@ using namespace mlir;
|
|||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||||
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||||
|
static int i = 0;
|
||||||
|
Location loc = constantOp.getLoc();
|
||||||
|
|
||||||
|
if (hasWeightAlways(constantOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||||
|
|
||||||
|
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||||
|
|
||||||
|
if (constRankedTensorType) {
|
||||||
|
mlir::MemRefType memRefType =
|
||||||
|
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||||
|
std::string argName = "const_" + std::to_string(i++);
|
||||||
|
memref::GlobalOp::create(rewriter,
|
||||||
|
loc,
|
||||||
|
rewriter.getStringAttr(argName),
|
||||||
|
rewriter.getStringAttr("private"),
|
||||||
|
TypeAttr::get(memRefType),
|
||||||
|
constantOp.getValueAttr(),
|
||||||
|
rewriter.getUnitAttr(),
|
||||||
|
{});
|
||||||
|
|
||||||
|
for (auto& constUses : constantOp->getUses()) {
|
||||||
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
|
||||||
|
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(toTensor);
|
||||||
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
llvm_unreachable("Who are using const globally");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||||
|
llvm::DenseMap<spatial::SpatCompute, Value> mapSpatComputeToConst;
|
||||||
|
|
||||||
|
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||||
|
auto constUsers = constUses.getOwner();
|
||||||
|
|
||||||
|
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||||
|
|
||||||
|
auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex();
|
||||||
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
|
||||||
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
|
BBArgValue.replaceAllUsesWith(newConst->getResult(0));
|
||||||
|
spatCompute.getInputsMutable().erase(BBArgIndex);
|
||||||
|
spatCompute.getBody().front().eraseArgument(BBArgIndex);
|
||||||
|
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto parent = constUsers->getParentOfType<spatial::SpatCompute>();
|
||||||
|
assert(parent && "Global Constant used direcly not within a compute");
|
||||||
|
if (!mapSpatComputeToConst.contains(parent)) {
|
||||||
|
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||||
|
auto newConst = rewriter.clone(*constantOp);
|
||||||
|
mapSpatComputeToConst.insert({parent, newConst->getResult(0)});
|
||||||
|
}
|
||||||
|
constUses.set(mapSpatComputeToConst[parent]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto parent = constantOp->getParentOp();
|
||||||
|
rewriter.eraseOp(constantOp);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
@@ -42,13 +140,13 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
std::string argName = "arg_" + std::to_string(index);
|
std::string argName = "arg_" + std::to_string(index);
|
||||||
|
|
||||||
memref::GlobalOp::create(rewriter,
|
memref::GlobalOp::create(rewriter,
|
||||||
loc,
|
loc,
|
||||||
rewriter.getStringAttr(argName),
|
rewriter.getStringAttr(argName),
|
||||||
rewriter.getStringAttr("private"),
|
rewriter.getStringAttr("private"),
|
||||||
TypeAttr::get(memRefType),
|
TypeAttr::get(memRefType),
|
||||||
{},
|
{},
|
||||||
{},
|
{},
|
||||||
{});
|
{});
|
||||||
|
|
||||||
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
for (auto& argUses : llvm::make_early_inc_range(arg.getUses())) {
|
||||||
auto argUser = argUses.getOwner();
|
auto argUser = argUses.getOwner();
|
||||||
@@ -57,8 +155,8 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||||
auto toTensor = bufferization::ToTensorOp::create(rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
auto toTensor = bufferization::ToTensorOp::create(
|
||||||
|
rewriter, loc, argRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||||
|
|
||||||
rewriter.startOpModification(spatCompute.getOperation());
|
rewriter.startOpModification(spatCompute.getOperation());
|
||||||
BBArgValue.replaceAllUsesWith(toTensor);
|
BBArgValue.replaceAllUsesWith(toTensor);
|
||||||
@@ -82,7 +180,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
void populateGlobalTensorToMemrefPatterns(RewritePatternSet& patterns) {
|
||||||
patterns.add<FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
patterns.add<FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||||
|
|
||||||
#include "llvm/ADT/SmallSet.h"
|
#include "llvm/ADT/SmallSet.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
@@ -165,10 +166,7 @@ void SpatialToPimPass::runOnOperation() {
|
|||||||
RewritePatternSet patterns(ctx);
|
RewritePatternSet patterns(ctx);
|
||||||
populateGlobalTensorToMemrefPatterns(patterns);
|
populateGlobalTensorToMemrefPatterns(patterns);
|
||||||
|
|
||||||
if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) {
|
walkAndApplyPatterns(moduleOp, std::move(patterns));
|
||||||
signalPassFailure();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
auto entryFunc = getPimEntryFunc(moduleOp);
|
auto entryFunc = getPimEntryFunc(moduleOp);
|
||||||
@@ -504,6 +502,7 @@ LogicalResult SpatialToPimPass::allocateAndInitializeCoreLocalVariables(func::Fu
|
|||||||
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
rewriter.replaceAllUsesExcept(inputTensor, memCopyHostToDevOp.getResult(), {memCopyHostToDevOp});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
for (auto& op : funcOp.getBody().getOps())
|
for (auto& op : funcOp.getBody().getOps())
|
||||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||||
assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle");
|
assert(computeOp.getInputs().size() == 0 && "Already removed from mergeNode and global input handle");
|
||||||
|
|||||||
@@ -95,23 +95,22 @@ void generateReport(func::FuncOp funcOp, const std::string& name) {
|
|||||||
auto expectedPrintedValue = currentComputeId + 1;
|
auto expectedPrintedValue = currentComputeId + 1;
|
||||||
bool rangePrinted = false;
|
bool rangePrinted = false;
|
||||||
cI++;
|
cI++;
|
||||||
for (; cI <= lastIndex; ++cI){
|
for (; cI <= lastIndex; ++cI) {
|
||||||
auto candidateToPrint = std::get<0>(collectedData[cI]);
|
auto candidateToPrint = std::get<0>(collectedData[cI]);
|
||||||
if (candidateToPrint == expectedPrintedValue){
|
if (candidateToPrint == expectedPrintedValue) {
|
||||||
expectedPrintedValue = candidateToPrint + 1;
|
expectedPrintedValue = candidateToPrint + 1;
|
||||||
rangePrinted = true;
|
rangePrinted = true;
|
||||||
} else {
|
}
|
||||||
if (rangePrinted) {
|
else {
|
||||||
|
if (rangePrinted)
|
||||||
os << " - " << expectedPrintedValue - 1;
|
os << " - " << expectedPrintedValue - 1;
|
||||||
}
|
|
||||||
os << " , " << candidateToPrint;
|
os << " , " << candidateToPrint;
|
||||||
rangePrinted = false;
|
rangePrinted = false;
|
||||||
expectedPrintedValue = candidateToPrint + 1;
|
expectedPrintedValue = candidateToPrint + 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (rangePrinted && currentComputeId != expectedPrintedValue - 1){
|
if (rangePrinted && currentComputeId != expectedPrintedValue - 1)
|
||||||
os << " - " << expectedPrintedValue - 1;
|
os << " - " << expectedPrintedValue - 1;
|
||||||
}
|
|
||||||
|
|
||||||
os << " :\n";
|
os << " :\n";
|
||||||
os << "\tNumber of instructions " << currentNumInst << "\n";
|
os << "\tNumber of instructions " << currentNumInst << "\n";
|
||||||
@@ -193,11 +192,34 @@ public:
|
|||||||
|
|
||||||
LogicalResult initialize(MLIRContext* context) override { return success(); }
|
LogicalResult initialize(MLIRContext* context) override { return success(); }
|
||||||
|
|
||||||
|
void verifyOrderAssumption(std::vector<spatial::SpatCompute>& dominanceOrderCompute) {
|
||||||
|
uint64_t computeNumber = 0;
|
||||||
|
llvm::DenseSet<SpatCompute> visited;
|
||||||
|
mlir::func::FuncOp funcOp = getOperation();
|
||||||
|
for (auto spatCompute : funcOp.getOps<SpatCompute>())
|
||||||
|
computeNumber++;
|
||||||
|
|
||||||
|
assert(computeNumber == dominanceOrderCompute.size());
|
||||||
|
|
||||||
|
for(auto domCompute : dominanceOrderCompute){
|
||||||
|
visited.insert(domCompute);
|
||||||
|
for(auto domInput : domCompute.getInputs() ){
|
||||||
|
if(auto domImputAsCompute = dyn_cast_if_present<SpatCompute>(domInput.getDefiningOp())){
|
||||||
|
assert(visited.contains(domImputAsCompute) && "Dominance order violated\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
|
||||||
auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
|
auto& lastComputeOfCpu = analysisResult.isLastComputeOfCpu;
|
||||||
auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
|
auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
|
||||||
|
|
||||||
|
func::FuncOp func = getOperation();
|
||||||
|
verifyOrderAssumption(analysisResult.dominanceOrderCompute);
|
||||||
|
|
||||||
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
||||||
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
||||||
if (!cpuToNewComputeMap.contains(cpu)) {
|
if (!cpuToNewComputeMap.contains(cpu)) {
|
||||||
@@ -219,11 +241,19 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||||
for (auto users : computeNodeToRemove->getUsers())
|
if (!computeNodeToRemove->use_empty()) {
|
||||||
|
llvm::dbgs() << "Full module\n";
|
||||||
|
computeNodeToRemove->getParentOfType<ModuleOp>()->dump();
|
||||||
|
|
||||||
|
llvm::dbgs() << "Compute with uses:\n";
|
||||||
|
computeNodeToRemove.dump();
|
||||||
|
}
|
||||||
|
for (auto users : computeNodeToRemove->getUsers()) {
|
||||||
|
llvm::dbgs() << "Users:\n";
|
||||||
users->dump();
|
users->dump();
|
||||||
|
}
|
||||||
computeNodeToRemove.erase();
|
computeNodeToRemove.erase();
|
||||||
}
|
}
|
||||||
func::FuncOp func = getOperation();
|
|
||||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||||
generateReport(func, "spatial1_dcp_merged_report");
|
generateReport(func, "spatial1_dcp_merged_report");
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user