MergeDCP pass all test
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s
This commit is contained in:
@@ -5,11 +5,15 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
@@ -42,6 +46,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
|
||||
|
||||
private:
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) const;
|
||||
void encapsulateGlobalInstruction(func::FuncOp funcOp);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -126,11 +131,88 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
|
||||
|
||||
annotateWeightsConstants(*entryFunc);
|
||||
encapsulateGlobalInstruction(*entryFunc);
|
||||
|
||||
// Dump to file for debug
|
||||
dumpModule(moduleOp, "spatial");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
|
||||
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
|
||||
Value source = funcSource(toRemoveOp);
|
||||
rewriter.setInsertionPointAfter(toRemoveOp);
|
||||
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
|
||||
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
|
||||
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
|
||||
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
|
||||
rewriter.setInsertionPointToEnd(BB);
|
||||
IRMapping mapper;
|
||||
mapper.map(source, BB->getArgument(0));
|
||||
auto newInst = rewriter.clone(*inst, mapper);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
|
||||
inst->replaceAllUsesWith(newCompute);
|
||||
inst->erase();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Wraps a tensor::ConcatOp inside a fresh spatial::SpatWeightedCompute when
/// at least one of its inputs is produced by a SpatWeightedCompute.
///
/// \returns true when `inst` was replaced and erased; the caller must not
///          touch `inst` afterwards.
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
  if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
    auto sources = toRemoveOp.getInputs();
    rewriter.setInsertionPointAfter(toRemoveOp);
    if (llvm::any_of(
            sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
      auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
      // One block argument per concat input, all at the pass location.
      llvm::SmallVector<Type> sourceTypes;
      llvm::SmallVector<Location> sourceLocs;
      sourceTypes.reserve(sources.size());
      sourceLocs.reserve(sources.size());
      for (auto source : sources) {
        sourceTypes.push_back(source.getType());
        sourceLocs.push_back(loc);
      }
      auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLocs);
      // Segment sizes: 0 weights, sources.size() inputs. Use a named cast
      // instead of the C-style narrowing cast.
      newCompute.getProperties().setOperandSegmentSizes({0, static_cast<int32_t>(sources.size())});
      rewriter.setInsertionPointToEnd(BB);
      // Clone the concat into the new block, rerouting each input to the
      // matching block argument so the region is self-contained.
      IRMapping mapper;
      for (auto [source, bbArg] : llvm::zip(sources, BB->getArguments()))
        mapper.map(source, bbArg);
      auto newConcat = rewriter.clone(*inst, mapper);
      spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
      inst->replaceAllUsesWith(newCompute);
      inst->erase();
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// TODO what we want to keep in global?
|
||||
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
|
||||
Location loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
bool keep = true;
|
||||
while (keep) {
|
||||
keep = false;
|
||||
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
|
||||
keep |= encapsulator<tensor::ExtractSliceOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
|
||||
|
||||
keep |= encapsulator<tensor::ExpandShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
|
||||
|
||||
keep |= encapsulator<ONNXTransposeOp>(
|
||||
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
|
||||
|
||||
keep |= encapsulator<tensor::CollapseShapeOp>(
|
||||
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
|
||||
|
||||
keep |= encapsulateConcat(rewriter, loc, &instruction);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
bool isAlwaysWeight =
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
@@ -17,6 +18,19 @@ namespace spatial {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Walks backwards through a chain of tensor.extract_slice ops starting at
/// `op` and returns the SpatWeightedCompute that ultimately produces the
/// value, or a null op when the chain ends elsewhere (null op, block
/// argument, or an unrelated producer).
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
  // Hop over any number of extract_slice ops; dyn_cast_if_present handles a
  // null `op` at every step, so no separate null checks are needed.
  while (auto slice = llvm::dyn_cast_if_present<tensor::ExtractSliceOp>(op))
    op = slice.getSource().getDefiningOp();
  return llvm::dyn_cast_if_present<SpatWeightedCompute>(op);
}
|
||||
|
||||
DCPAnalysisResult DCPAnalysis::runAnalysis() {
|
||||
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
|
||||
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
||||
@@ -27,8 +41,7 @@ DCPAnalysisResult DCPAnalysis::runAnalysis() {
|
||||
|
||||
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
||||
for (Value input : spatWeightedCompute.getInputs()) {
|
||||
if (auto spatWeightedComputeArgOp = llvm::dyn_cast_if_present<SpatWeightedCompute>(input.getDefiningOp());
|
||||
spatWeightedComputeArgOp) {
|
||||
if (auto spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp())) {
|
||||
auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp);
|
||||
assert(elemIter != spatWeightedComputes.end());
|
||||
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter);
|
||||
|
||||
@@ -11,9 +11,12 @@
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -62,15 +65,12 @@ public:
|
||||
auto [first, second] = channelNewInserter();
|
||||
channelNewOpVal = first;
|
||||
channelSendInserter = second;
|
||||
auto op = computeResults.innerValue.getDefiningOp();
|
||||
if (op) {
|
||||
insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op));
|
||||
}
|
||||
else {
|
||||
auto BB = computeResults.innerValue.getParentBlock();
|
||||
insertPointSend = InsertPoint(BB, BB->begin());
|
||||
}
|
||||
}
|
||||
if (!BB->empty() && isa<spatial::SpatYieldOp>(BB->back()))
|
||||
insertPointSend = InsertPoint(BB, --BB->end());
|
||||
else
|
||||
insertPointSend = InsertPoint(BB, BB->end());
|
||||
if (spatWeightedCompute) {
|
||||
for (auto& BB : spatWeightedCompute.getBody())
|
||||
if (&BB == insertPointSend.getBlock())
|
||||
@@ -129,8 +129,10 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute)))
|
||||
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
|
||||
for (auto users : computeNodetoRemove->getUsers()) users->dump();
|
||||
computeNodetoRemove.erase();
|
||||
}
|
||||
func::FuncOp func = getOperation();
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
|
||||
}
|
||||
@@ -154,36 +156,40 @@ private:
|
||||
newComputeOperand.push_back(arg);
|
||||
|
||||
for (auto arg : oldWeightedCompute.getInputs())
|
||||
if (!llvm::isa<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
newComputeOperand.push_back(arg);
|
||||
newBBOperandType.push_back(arg.getType());
|
||||
newBBLocations.push_back(loc);
|
||||
}
|
||||
|
||||
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
|
||||
|
||||
rewriter.createBlock(
|
||||
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
newWeightedCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||
rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front());
|
||||
|
||||
auto& newBB = newWeightedCompute.getBody().front();
|
||||
auto& oldBB = oldWeightedCompute.getBody().front();
|
||||
rewriter.setInsertionPointToEnd(&newBB);
|
||||
|
||||
int indexNew = 0;
|
||||
int indexOld = oldWeightedCompute.getWeights().size();
|
||||
int indexOldStart = oldWeightedCompute.getWeights().size();
|
||||
size_t indexOld = oldWeightedCompute.getWeights().size();
|
||||
size_t indexOldStart = oldWeightedCompute.getWeights().size();
|
||||
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
|
||||
if (!llvm::isa<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
|
||||
mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart),
|
||||
newWeightedCompute.getBody().front().getArgument(indexNew++));
|
||||
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
|
||||
}
|
||||
else {
|
||||
auto argWeightCompute =
|
||||
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
|
||||
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender();
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
|
||||
assert(isChannel == true);
|
||||
spatial::SpatChannelReceiveOp reciveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal);
|
||||
mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), reciveOp);
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), reciveOp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,19 +220,28 @@ private:
|
||||
|
||||
auto weightMutableIter = toCompute.getWeightsMutable();
|
||||
for (auto weight : fromCompute.getWeights()) {
|
||||
int sizeW = toCompute.getWeights().size();
|
||||
int sizeI = toCompute.getInputs().size();
|
||||
//TODO non clonare weight gia' presenti e poi vanno rimappate le nuove OP con i nuovi weight
|
||||
auto founded = llvm::find(toCompute.getWeights(), weight);
|
||||
if(founded == toCompute.getWeights().end()){
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
size_t sizeI = toCompute.getInputs().size();
|
||||
weightMutableIter.append(weight);
|
||||
auto last = weightMutableIter.end();
|
||||
last = std::prev(last,1);
|
||||
mapper.map(weight, last->get());
|
||||
assert(sizeW + 1 == toCompute.getWeights().size());
|
||||
assert(sizeI == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
}else {
|
||||
mapper.map(weight, *founded);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
auto& toBB = toCompute.getBody().front();
|
||||
auto& fromBB = fromCompute.getBody().front();
|
||||
auto inputeArgMutable = toCompute.getInputsMutable();
|
||||
// Insert reciveOp
|
||||
rewriter.setInsertionPointToEnd(&toCompute.getBody().front());
|
||||
int newBBindex = toCompute.getBody().front().getArguments().size();
|
||||
rewriter.setInsertionPointToEnd(&toBB);
|
||||
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
|
||||
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
@@ -236,30 +251,33 @@ private:
|
||||
if (channelOrLocal.isChannel) {
|
||||
spatial::SpatChannelReceiveOp reciveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
|
||||
mapper.map(fromCompute.getBody().front().getArgument(bbIndex), reciveOp.getResult());
|
||||
mapper.map(fromBB.getArgument(bbIndex), reciveOp.getResult());
|
||||
}
|
||||
else {
|
||||
mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data);
|
||||
mapper.map(fromBB.getArgument(bbIndex), channelOrLocal.data);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
int sizeW = toCompute.getWeights().size();
|
||||
int sizeI = toCompute.getInputs().size();
|
||||
auto founded = llvm::find(toCompute.getInputs(), arg);
|
||||
if(founded == toCompute.getInputs().end()){
|
||||
size_t sizeW = toCompute.getWeights().size();
|
||||
size_t sizeI = toCompute.getInputs().size();
|
||||
inputeArgMutable.append(arg);
|
||||
assert(sizeW == toCompute.getWeights().size());
|
||||
assert(sizeI + 1 == toCompute.getInputs().size());
|
||||
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
|
||||
|
||||
toCompute.getBody().front().addArgument(
|
||||
fromCompute.getBody().front().getArgument(bbIndex).getType(),loc);
|
||||
|
||||
mapper.map(fromCompute.getBody().front().getArgument(bbIndex),
|
||||
toCompute.getBody().front().getArgument(newBBindex++));
|
||||
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
|
||||
}else {
|
||||
auto distance = std::distance(toCompute.getInputs().begin(), founded);
|
||||
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto oldBBarg : fromCompute.getBody().front().getArguments())
|
||||
for (auto oldBBarg : fromBB.getArguments())
|
||||
assert(mapper.contains(oldBBarg));
|
||||
|
||||
ComputeValueResults computeValueResults;
|
||||
|
||||
Reference in New Issue
Block a user