MergeDCP passes all tests
All checks were successful
Validate Operations / validate-operations (push) Successful in 2h55m12s

This commit is contained in:
ilgeco
2026-04-08 20:39:01 +02:00
parent 813368f625
commit 3f870fb74b
3 changed files with 157 additions and 44 deletions

View File

@@ -5,11 +5,15 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
@@ -42,6 +46,7 @@ struct ONNXToSpatialPass : PassWrapper<ONNXToSpatialPass, OperationPass<ModuleOp
private:
void annotateWeightsConstants(func::FuncOp funcOp) const;
void encapsulateGlobalInstruction(func::FuncOp funcOp);
};
} // namespace
@@ -126,11 +131,88 @@ void ONNXToSpatialPass::runOnOperation() {
llvm::dbgs() << "Failed to run canonicalization cleanup, continuing...\n";
annotateWeightsConstants(*entryFunc);
encapsulateGlobalInstruction(*entryFunc);
// Dump to file for debug
dumpModule(moduleOp, "spatial");
}
template <typename T>
bool encapsulator(IRRewriter& rewriter, Location loc, Operation* inst, std::function<Value(T)> funcSource) {
if (T toRemoveOp = llvm::dyn_cast_if_present<T>(inst)) {
Value source = funcSource(toRemoveOp);
rewriter.setInsertionPointAfter(toRemoveOp);
if (isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp())) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), source);
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), {source.getType()}, {loc});
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) 1});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
mapper.map(source, BB->getArgument(0));
auto newInst = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newInst->getResult(0));
inst->replaceAllUsesWith(newCompute);
inst->erase();
return true;
}
}
return false;
}
bool encapsulateConcat(IRRewriter& rewriter, Location loc, Operation* inst) {
if (auto toRemoveOp = llvm::dyn_cast_if_present<tensor::ConcatOp>(inst)) {
auto sources = toRemoveOp.getInputs();
rewriter.setInsertionPointAfter(toRemoveOp);
if (llvm::any_of(
sources, [](auto source) { return isa_and_present<spatial::SpatWeightedCompute>(source.getDefiningOp()); })) {
auto newCompute = spatial::SpatWeightedCompute::create(rewriter, loc, inst->getResultTypes().front(), sources);
llvm::SmallVector<Type> sourceTypes;
llvm::SmallVector<Location> sourceLoc;
for (auto source : sources){
sourceTypes.push_back(source.getType());
sourceLoc.push_back(loc);
}
auto BB = rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), sourceTypes, sourceLoc);
newCompute.getProperties().setOperandSegmentSizes({(int) 0, (int) sources.size()});
rewriter.setInsertionPointToEnd(BB);
IRMapping mapper;
for(auto [source,bbArg] : llvm::zip(sources, BB->getArguments()))
mapper.map(source, bbArg);
auto newConcat = rewriter.clone(*inst, mapper);
spatial::SpatYieldOp::create(rewriter, loc, newConcat->getResult(0));
inst->replaceAllUsesWith(newCompute);
inst->erase();
return true;
}
}
return false;
}
// TODO what we want to keep in global?
void ONNXToSpatialPass::encapsulateGlobalInstruction(func::FuncOp funcOp) {
Location loc = funcOp.getLoc();
IRRewriter rewriter(&getContext());
bool keep = true;
while (keep) {
keep = false;
for (auto& instruction : llvm::make_early_inc_range(funcOp.getOps())) {
keep |= encapsulator<tensor::ExtractSliceOp>(
rewriter, loc, &instruction, [](tensor::ExtractSliceOp extract) { return extract.getSource(); });
keep |= encapsulator<tensor::ExpandShapeOp>(
rewriter, loc, &instruction, [](tensor::ExpandShapeOp expand) { return expand.getSrc(); });
keep |= encapsulator<ONNXTransposeOp>(
rewriter, loc, &instruction, [](ONNXTransposeOp transpose) { return transpose.getData(); });
keep |= encapsulator<tensor::CollapseShapeOp>(
rewriter, loc, &instruction, [](tensor::CollapseShapeOp collapse) { return collapse.getSrc(); });
keep |= encapsulateConcat(rewriter, loc, &instruction);
}
}
}
void ONNXToSpatialPass::annotateWeightsConstants(func::FuncOp funcOp) const {
funcOp.walk([&](arith::ConstantOp constantOp) {
bool isAlwaysWeight =

View File

@@ -1,3 +1,4 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
@@ -17,6 +18,19 @@ namespace spatial {
using namespace mlir;
SpatWeightedCompute getOriginalSpatWeightCompute(Operation* op) {
if (!op)
return {};
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
return res;
return {};
}
DCPAnalysisResult DCPAnalysis::runAnalysis() {
using EdgesIndex = std::tuple<int64_t, int64_t, int64_t>;
llvm::SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
@@ -27,8 +41,7 @@ DCPAnalysisResult DCPAnalysis::runAnalysis() {
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
for (Value input : spatWeightedCompute.getInputs()) {
if (auto spatWeightedComputeArgOp = llvm::dyn_cast_if_present<SpatWeightedCompute>(input.getDefiningOp());
spatWeightedComputeArgOp) {
if (auto spatWeightedComputeArgOp = getOriginalSpatWeightCompute(input.getDefiningOp())) {
auto elemIter = llvm::find(spatWeightedComputes, spatWeightedComputeArgOp);
assert(elemIter != spatWeightedComputes.end());
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), elemIter);

View File

@@ -11,9 +11,12 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -62,15 +65,12 @@ public:
auto [first, second] = channelNewInserter();
channelNewOpVal = first;
channelSendInserter = second;
auto op = computeResults.innerValue.getDefiningOp();
if (op) {
insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op));
}
else {
auto BB = computeResults.innerValue.getParentBlock();
insertPointSend = InsertPoint(BB, BB->begin());
}
}
if (!BB->empty() && isa<spatial::SpatYieldOp>(BB->back()))
insertPointSend = InsertPoint(BB, --BB->end());
else
insertPointSend = InsertPoint(BB, BB->end());
if (spatWeightedCompute) {
for (auto& BB : spatWeightedCompute.getBody())
if (&BB == insertPointSend.getBlock())
@@ -129,8 +129,10 @@ public:
}
}
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute)))
for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute))) {
for (auto users : computeNodetoRemove->getUsers()) users->dump();
computeNodetoRemove.erase();
}
func::FuncOp func = getOperation();
dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
}
@@ -154,36 +156,40 @@ private:
newComputeOperand.push_back(arg);
for (auto arg : oldWeightedCompute.getInputs())
if (!llvm::isa<SpatWeightedCompute>(arg.getDefiningOp())) {
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
newComputeOperand.push_back(arg);
newBBOperandType.push_back(arg.getType());
newBBLocations.push_back(loc);
}
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
rewriter.createBlock(
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
newWeightedCompute.getProperties().setOperandSegmentSizes(
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front());
auto& newBB = newWeightedCompute.getBody().front();
auto& oldBB = oldWeightedCompute.getBody().front();
rewriter.setInsertionPointToEnd(&newBB);
int indexNew = 0;
int indexOld = oldWeightedCompute.getWeights().size();
int indexOldStart = oldWeightedCompute.getWeights().size();
size_t indexOld = oldWeightedCompute.getWeights().size();
size_t indexOldStart = oldWeightedCompute.getWeights().size();
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
if (!llvm::isa<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart),
newWeightedCompute.getBody().front().getArgument(indexNew++));
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
}
else {
auto argWeightCompute =
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender();
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
assert(isChannel == true);
spatial::SpatChannelReceiveOp reciveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal);
mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), reciveOp);
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
mapper.map(oldBB.getArgument(indexOld - indexOldStart), reciveOp);
}
}
@@ -214,19 +220,28 @@ private:
auto weightMutableIter = toCompute.getWeightsMutable();
for (auto weight : fromCompute.getWeights()) {
int sizeW = toCompute.getWeights().size();
int sizeI = toCompute.getInputs().size();
//TODO: do not clone weights that are already present; the new ops must then be remapped to the new weights
auto founded = llvm::find(toCompute.getWeights(), weight);
if(founded == toCompute.getWeights().end()){
size_t sizeW = toCompute.getWeights().size();
size_t sizeI = toCompute.getInputs().size();
weightMutableIter.append(weight);
auto last = weightMutableIter.end();
last = std::prev(last,1);
mapper.map(weight, last->get());
assert(sizeW + 1 == toCompute.getWeights().size());
assert(sizeI == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
}else {
mapper.map(weight, *founded);
}
}
auto& toBB = toCompute.getBody().front();
auto& fromBB = fromCompute.getBody().front();
auto inputeArgMutable = toCompute.getInputsMutable();
// Insert reciveOp
rewriter.setInsertionPointToEnd(&toCompute.getBody().front());
int newBBindex = toCompute.getBody().front().getArguments().size();
rewriter.setInsertionPointToEnd(&toBB);
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
@@ -236,30 +251,33 @@ private:
if (channelOrLocal.isChannel) {
spatial::SpatChannelReceiveOp reciveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
mapper.map(fromCompute.getBody().front().getArgument(bbIndex), reciveOp.getResult());
mapper.map(fromBB.getArgument(bbIndex), reciveOp.getResult());
}
else {
mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data);
mapper.map(fromBB.getArgument(bbIndex), channelOrLocal.data);
}
}
else {
int sizeW = toCompute.getWeights().size();
int sizeI = toCompute.getInputs().size();
auto founded = llvm::find(toCompute.getInputs(), arg);
if(founded == toCompute.getInputs().end()){
size_t sizeW = toCompute.getWeights().size();
size_t sizeI = toCompute.getInputs().size();
inputeArgMutable.append(arg);
assert(sizeW == toCompute.getWeights().size());
assert(sizeI + 1 == toCompute.getInputs().size());
assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
toCompute.getBody().front().addArgument(
fromCompute.getBody().front().getArgument(bbIndex).getType(),loc);
mapper.map(fromCompute.getBody().front().getArgument(bbIndex),
toCompute.getBody().front().getArgument(newBBindex++));
toBB.addArgument(fromBB.getArgument(bbIndex).getType(), loc);
mapper.map(fromBB.getArgument(bbIndex), toBB.getArguments().back());
}else {
auto distance = std::distance(toCompute.getInputs().begin(), founded);
mapper.map(fromBB.getArgument(bbIndex), toBB.getArgument(distance));
}
}
}
for (auto oldBBarg : fromCompute.getBody().front().getArguments())
for (auto oldBBarg : fromBB.getArguments())
assert(mapper.contains(oldBBarg));
ComputeValueResults computeValueResults;