Add DCP alghoritm, partial working test

This commit is contained in:
ilgeco
2026-04-07 22:05:39 +02:00
parent ef4743c986
commit ca56e3d4f1
17 changed files with 1313 additions and 33 deletions

View File

@@ -0,0 +1,318 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
#include <functional>
#include <memory>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatWeightedCompute = spatial::SpatWeightedCompute;
// Carries the SSA value produced inside a (re)built compute node's body —
// the operand of its SpatYieldOp after remapping into the new region.
struct ComputeValueResults {
  // Value yielded by the yieldOp
  Value innerValue;
};
// Lazily materializes the channel plumbing that exposes a compute node's
// yielded value to other compute nodes. Neither the channel-new op nor any
// channel-send op is created until the first consumer actually asks for the
// value via getAsChannelValueAndInsertSender().
class LazyInsertComputeResult {
  using InsertPoint = mlir::IRRewriter::InsertPoint;
  // Value yielded inside the producing compute node (see ComputeValueResults).
  ComputeValueResults computeResults;
  // Result of the lazily created channel-new op; valid only after the first
  // call to getAsChannelValueAndInsertSender().
  Value channelNewOpVal;
  // Stored flag, exposed via onlyChanneled(); not otherwise read here.
  bool onlyChannel;
  // Thunk that inserts one channel-send op at the given point; nullptr acts
  // as the "channel not yet created" sentinel.
  std::function<void(InsertPoint insertPoint)> channelSendInserter;
  // Where sends are inserted: right after the value's defining op, or at the
  // start of its block when the value is a block argument.
  InsertPoint insertPointSend;
  // Thunk that creates the channel-new op and returns {channel value,
  // send-inserter thunk}.
  std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
public:
  LazyInsertComputeResult(ComputeValueResults computeValueResults,
      std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
      bool isOnlyChannel)
      : computeResults(computeValueResults),
        onlyChannel(isOnlyChannel),
        channelSendInserter(nullptr),
        insertPointSend({}),
        channelNewInserter(channelNewInserter) {}
  // Either a channel value (isChannel == true) or the producer's local value
  // when producer and consumer ended up in the same compute node.
  struct ChannelOrLocalOp {
    Value data;
    bool isChannel;
  };
  bool onlyChanneled() const { return onlyChannel; }
  // Returns the value the consumer should read. On first use this creates
  // the channel and computes the send insertion point; if `spatWeightedCompute`
  // (the consuming node) contains that point, producer and consumer live in
  // the same node, so the raw local value is returned and no send is emitted.
  // Otherwise one send op is inserted per call — apparently one send is
  // required per receiving consumer; NOTE(review): confirm this matches the
  // channel op semantics.
  ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute spatWeightedCompute) {
    if (channelSendInserter == nullptr) {
      // First request: create the channel and remember how to insert sends.
      auto [first, second] = channelNewInserter();
      channelNewOpVal = first;
      channelSendInserter = second;
      auto op = computeResults.innerValue.getDefiningOp();
      if (op) {
        // Send immediately after the op that defines the value.
        insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op));
      }
      else {
        // Block argument: send at the top of its owning block.
        auto BB = computeResults.innerValue.getParentBlock();
        insertPointSend = InsertPoint(BB, BB->begin());
      }
    }
    if (spatWeightedCompute) {
      // Consumer is a compute node: if the send point lies inside it, the
      // producer was merged into the same node — hand back the local value.
      for (auto& BB : spatWeightedCompute.getBody())
        if (&BB == insertPointSend.getBlock())
          return {computeResults.innerValue, false};
    }
    channelSendInserter(insertPointSend);
    return {channelNewOpVal, true};
  }
  // Convenience overload for consumers that are not compute nodes.
  ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
};
/// Pass that merges all Spatial-Weighted-Compute nodes mapped to the same CPU
/// into a single compute node, wiring inter-CPU data flow through channels
/// (channel-new at the top of the function, send after the producing value,
/// receive inside the consuming node).
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {
private:
  // Lazy channel/send machinery for each ORIGINAL compute node's result.
  DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
  // Maps each original compute node to the merged node that absorbed it.
  DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
  // Maps a CPU id to the single merged compute node built for that CPU.
  DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)

  StringRef getArgument() const override { return "pim-merge-node-pass"; }

  StringRef getDescription() const override {
    return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
           "execution time";
  }

  LogicalResult initialize(MLIRContext* context) override { return success(); }

  void runOnOperation() override {
    DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
    auto& lastComputeOfCpu = analysisResult.isLastComputeOfACpu;
    auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
    IRRewriter rewriter(&getContext());
    // Dominance order guarantees every producer is rewritten before any of
    // its consumers, so newComputeNodeResults is always populated in time.
    for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
      size_t cpu = analysisResult.computeToCPUMap.at(currentComputeNode);
      bool isLast = lastComputeOfCpu.contains(currentComputeNode);
      std::pair<SpatWeightedCompute, ComputeValueResults> rebuilt;
      if (!cpuToNewComputeMap.contains(cpu)) {
        // First node seen for this CPU: clone it into a fresh compute op
        // typed after the CPU's final compute (whose results survive).
        ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
        rebuilt = createNewComputeNode(currentComputeNode, newWeightedComputeType, isLast);
        cpuToNewComputeMap[cpu] = rebuilt.first;
      }
      else {
        // This CPU already has a merged node: append this node's body to it.
        rebuilt = mergeIntoComputeNode(cpuToNewComputeMap[cpu], currentComputeNode, isLast);
      }
      newComputeNodeResults.insert(std::make_pair(
          currentComputeNode, createLazyComputeResult(rebuilt.first, rebuilt.second, isLast)));
    }
    // Erase originals in reverse dominance order so users die before defs.
    for (auto computeNodeToRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute)))
      computeNodeToRemove.erase();
    func::FuncOp func = getOperation();
    dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
  }

private:
  /// Clones `oldWeightedCompute` into a brand-new compute node placed just
  /// before the function terminator. Inputs produced by other compute nodes
  /// become channel receives; every other input becomes a real operand plus
  /// block argument. The SpatYieldOp is cloned only when `lastCompute` is
  /// true (only the final node of a CPU keeps a yield).
  std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
      SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
    // Insert right before the function terminator (last op of entry block).
    rewriter.setInsertionPoint(&*std::prev(func.getBody().front().end(), 1));
    ComputeValueResults computeValueResults;
    IRMapping mapper;
    llvm::SmallVector<Value> newComputeOperand;
    llvm::SmallVector<Type> newBBOperandType;
    llvm::SmallVector<Location> newBBLocations;
    for (auto arg : oldWeightedCompute.getWeights())
      newComputeOperand.push_back(arg);
    for (auto arg : oldWeightedCompute.getInputs()) {
      // dyn_cast_if_present: block arguments have a null defining op, and
      // plain isa<> on nullptr is not allowed.
      if (!llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
        newComputeOperand.push_back(arg);
        newBBOperandType.push_back(arg.getType());
        newBBLocations.push_back(loc);
      }
    }
    auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
    rewriter.createBlock(
        &newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
    newWeightedCompute.getProperties().setOperandSegmentSizes(
        {static_cast<int32_t>(oldWeightedCompute.getWeights().size()),
         static_cast<int32_t>(newBBOperandType.size())});
    rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front());
    // Map each old block argument either to the corresponding new block
    // argument, or to a receive fed by the producing compute node's channel.
    unsigned indexNew = 0;
    unsigned indexOldStart = oldWeightedCompute.getWeights().size();
    for (unsigned indexOld = indexOldStart; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
      auto argWeightCompute =
          llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
      if (!argWeightCompute) {
        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart),
            newWeightedCompute.getBody().front().getArgument(indexNew++));
      }
      else {
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
        auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender();
        spatial::SpatChannelReceiveOp receiveOp =
            spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal);
        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), receiveOp);
      }
    }
    // Clone the body. Record the remapped yield operand; clone the yield op
    // itself only for the CPU's final compute.
    for (auto& op : oldWeightedCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
      else
        rewriter.clone(op, mapper);
    }
    // Redirect any function return that consumed the old node's result.
    for (auto user : oldWeightedCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(user))
        funcRet.setOperand(0, newWeightedCompute.getResult(0));
    oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
    return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
  }

  /// Appends `fromCompute`'s body to the already-built merged node
  /// `toCompute`. Weights are appended to the weight operand segment; inputs
  /// produced by other compute nodes become channel receives unless the
  /// producer already lives inside `toCompute` (then its local value is used
  /// directly); all other inputs become new operands plus block arguments.
  std::pair<SpatWeightedCompute, ComputeValueResults>
  mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
    IRMapping mapper;
    auto weightMutableIter = toCompute.getWeightsMutable();
    for (auto weight : fromCompute.getWeights()) {
      // Each append must grow exactly the weight segment, nothing else.
      size_t sizeW = toCompute.getWeights().size();
      size_t sizeI = toCompute.getInputs().size();
      weightMutableIter.append(weight);
      assert(sizeW + 1 == toCompute.getWeights().size());
      assert(sizeI == toCompute.getInputs().size());
      assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
    }
    auto inputArgMutable = toCompute.getInputsMutable();
    // Insert receive ops at the end of the merged body.
    rewriter.setInsertionPointToEnd(&toCompute.getBody().front());
    unsigned newBBindex = toCompute.getBody().front().getArguments().size();
    for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
      if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
        LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
            lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
        if (channelOrLocal.isChannel) {
          spatial::SpatChannelReceiveOp receiveOp =
              spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), receiveOp.getResult());
        }
        else {
          // Producer was merged into the same node: use its value directly.
          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data);
        }
      }
      else {
        // Plain input (e.g. constant or function argument): append it as a
        // new input operand and a matching block argument.
        size_t sizeW = toCompute.getWeights().size();
        size_t sizeI = toCompute.getInputs().size();
        inputArgMutable.append(arg);
        assert(sizeW == toCompute.getWeights().size());
        assert(sizeI + 1 == toCompute.getInputs().size());
        assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
        toCompute.getBody().front().addArgument(
            fromCompute.getBody().front().getArgument(bbIndex).getType(), loc);
        mapper.map(fromCompute.getBody().front().getArgument(bbIndex),
            toCompute.getBody().front().getArgument(newBBindex++));
      }
    }
    // Every old block argument must have been remapped above.
    for (auto oldBBarg : fromCompute.getBody().front().getArguments())
      assert(mapper.contains(oldBBarg));
    ComputeValueResults computeValueResults;
    // Clone the body; as in createNewComputeNode, only the CPU's last
    // compute keeps its yield.
    for (auto& op : fromCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
      else {
        rewriter.clone(op, mapper);
      }
    }
    // Redirect any function return that consumed the absorbed node's result.
    for (auto user : fromCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(user))
        funcRet.setOperand(0, toCompute.getResult(0));
    oldToNewComputeMap.insert({fromCompute, toCompute});
    return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
  }

  /// Builds the lazy channel machinery for `weightedCompute`'s yielded value:
  /// a thunk that, on first use, creates a channel-new op at the top of the
  /// function and returns a second thunk that inserts the matching send.
  /// `lastCompute` is currently unused but kept for interface stability.
  LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
      ComputeValueResults computeValueResults,
      bool lastCompute) {
    func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
    auto* context = &getContext();
    auto loc = funcOp.getLoc();
    IRRewriter rewriter(context);
    // Channels are declared at the very top of the function body.
    rewriter.setInsertionPointToStart(&funcOp.front());
    auto saveInsertionPointChnNew = rewriter.saveInsertionPoint();
    auto insertNew = [saveInsertionPointChnNew, context, loc, computeValueResults]() {
      IRRewriter rewriter(context);
      rewriter.restoreInsertionPoint(saveInsertionPointChnNew);
      auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
      auto channelVal = channelOp.getResult();
      // Capture `context` BY VALUE: this inner thunk outlives the enclosing
      // closure's activation, and the owning std::function may be copied or
      // moved (e.g. on DenseMap rehash), so a by-reference capture dangles.
      auto insertVal =
          [context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint insertPointChnSend) {
            IRRewriter rewriter(context);
            rewriter.restoreInsertionPoint(insertPointChnSend);
            auto spatSend =
                spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
            return spatSend;
          };
      std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret{channelVal, insertVal};
      return ret;
    };
    return LazyInsertComputeResult(computeValueResults, insertNew, false);
  }
};
} // namespace
/// Factory: returns a fresh instance of the compute-node merging pass.
std::unique_ptr<Pass> createMergeComputeNodePass() {
  return std::make_unique<MergeComputeNodePass>();
}
} // namespace onnx_mlir