Add DCP algorithm, partially working test
This commit is contained in:
@@ -0,0 +1,318 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/DCPGraph/DCPAnalysis.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
using SpatWeightedCompute = spatial::SpatWeightedCompute;
|
||||
|
||||
// Carries the SSA value a compute node produces internally, so later
// rewrites can wire it to consumers in other nodes.
struct ComputeValueResults {
  // The operand of the region's SpatYieldOp, after being remapped into
  // the new/merged compute node's body (set via mapper.lookup on clone).
  Value innerValue;
};
|
||||
|
||||
// Lazily materializes the channel plumbing (SpatChannelNewOp +
// SpatChannelSendOp) needed to forward a compute node's yielded value to a
// consumer in a *different* compute node. Nothing is inserted until the
// first time a cross-node consumer actually asks for the value; a consumer
// that ends up in the same block as the producer receives the raw SSA
// value instead, and no channel is created for that request.
class LazyInsertComputeResult {
  using InsertPoint = mlir::IRRewriter::InsertPoint;
  // The producer's yielded value (see ComputeValueResults).
  ComputeValueResults computeResults;
  // Result of the lazily created SpatChannelNewOp; only meaningful once
  // channelSendInserter has been populated.
  Value channelNewOpVal;
  // Recorded at construction; exposed through onlyChanneled().
  bool onlyChannel;
  // Inserts the SpatChannelSendOp at a given point. Doubles as the
  // "channel already materialized" flag: nullptr until first use.
  std::function<void(InsertPoint insertPoint)> channelSendInserter;
  // Where the send op belongs: right after the producer of innerValue.
  InsertPoint insertPointSend;
  // Deferred creator of the SpatChannelNewOp; yields the channel value and
  // the matching send-inserter callback.
  std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;

public:
  // computeValueResults: the producer's yielded value.
  // channelNewInserter:  deferred SpatChannelNewOp factory (see above).
  // isOnlyChannel:       stored verbatim; callers currently pass false
  //                      (see createLazyComputeResult).
  LazyInsertComputeResult(ComputeValueResults computeValueResults,
      std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
      bool isOnlyChannel)
      : computeResults(computeValueResults),
        onlyChannel(isOnlyChannel),
        channelSendInserter(nullptr),
        insertPointSend({}),
        channelNewInserter(channelNewInserter) {}

  // Discriminated result: either the channel value (isChannel == true) or
  // the producer's local SSA value when no channel is required.
  struct ChannelOrLocalOp {
    Value data;
    bool isChannel;
  };

  bool onlyChanneled() const { return onlyChannel; }

  // Returns the value the consumer should read. If spatWeightedCompute
  // (the consumer node, may be null) owns the block where the send op
  // would go, producer and consumer were merged into the same node and the
  // inner value is returned directly (isChannel == false). Otherwise the
  // channel is materialized on first use, the send op is inserted after
  // the producer, and the channel value is returned (isChannel == true).
  // NOTE(review): on the channel path channelSendInserter(insertPointSend)
  // runs on *every* call, so two cross-node consumers of the same producer
  // would appear to insert duplicate send ops — confirm intended.
  ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute spatWeightedCompute) {

    if (channelSendInserter == nullptr) {
      // First request: create the channel and remember where the send op
      // belongs.
      auto [first, second] = channelNewInserter();
      channelNewOpVal = first;
      channelSendInserter = second;
      auto op = computeResults.innerValue.getDefiningOp();
      if (op) {
        // Immediately after the defining op.
        insertPointSend = InsertPoint(op->getBlock(), ++Block::iterator(op));
      }
      else {
        // innerValue is a block argument: start of its block.
        auto BB = computeResults.innerValue.getParentBlock();
        insertPointSend = InsertPoint(BB, BB->begin());
      }
    }
    if (spatWeightedCompute) {
      // Consumer lives in the same node as the producer: no channel needed.
      for (auto& BB : spatWeightedCompute.getBody())
        if (&BB == insertPointSend.getBlock())
          return {computeResults.innerValue, false};
    }
    channelSendInserter(insertPointSend);
    return {channelNewOpVal, true};
  }

  // Convenience overload: no consumer node, always go through the channel.
  ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
};
|
||||
|
||||
// Pass "pim-merge-node-pass": merges SpatWeightedCompute nodes that the DCP
// analysis assigned to the same CPU into a single compute node, wiring
// cross-CPU dependencies through channel send/receive ops.
struct MergeComputeNodePass : PassWrapper<MergeComputeNodePass, OperationPass<func::FuncOp>> {

private:
  // Per original compute node: lazy channel plumbing for its result.
  // NOTE(review): LazyInsertComputeResult stores std::function state that
  // its callbacks reference; DenseMap growth moves entries — confirm no
  // callback outlives a rehash (see createLazyComputeResult).
  DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
  // Original node -> merged node that absorbed it (written, not read here).
  DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
  // CPU id -> the single merged compute node built for that CPU.
  DenseMap<int64_t, SpatWeightedCompute> cputToNewComputeMap;

public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodePass)

  StringRef getArgument() const override { return "pim-merge-node-pass"; }
  StringRef getDescription() const override {
    return "Merge Spatial-Weighted-Compute-Node in order to reduce the total "
           "execution time";
  }

  LogicalResult initialize(MLIRContext* context) override { return success(); }

  // Walks compute nodes in dominance order: the first node seen for a CPU
  // seeds a fresh merged node; later nodes of that CPU are merged into it.
  // Original nodes are erased afterwards (in reverse, so users go first).
  void runOnOperation() override {
    DCPAnalysisResult& analysisResult = getAnalysis<spatial::DCPAnalysis>().getResult();
    auto& lastComputeOfCpu = analysisResult.isLastComputeOfACpu;
    auto& cpuToLastComputeMap = analysisResult.cpuToLastComputeMap;
    IRRewriter rewriter(&getContext());  // (unused here; helpers build their own)

    for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
      size_t cpu = analysisResult.computeToCPUMap.at(currentComputeNode);
      if (!cputToNewComputeMap.contains(cpu)) {
        // The merged node's result types come from the CPU's *last*
        // compute node, since that is the one whose value escapes.
        ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
        auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
            currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
        cputToNewComputeMap[cpu] = newWeightedCompute;
        newComputeNodeResults.insert(
            std::make_pair(currentComputeNode,
                createLazyComputeResult(
                    newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
      }
      else {
        auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
            cputToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
        newComputeNodeResults.insert(
            std::make_pair(currentComputeNode,
                createLazyComputeResult(
                    newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
      }
    }

    // Reverse dominance order so each node's users are erased before it.
    for (auto computeNodetoRemove : llvm::make_early_inc_range(llvm::reverse(analysisResult.dominanceOrderCompute)))
      computeNodetoRemove.erase();
    func::FuncOp func = getOperation();
    dumpModule(cast<ModuleOp>(func->getParentOp()), "SpatialDCPMerged");
  }

private:
  // Clones oldWeightedCompute into a brand-new node inserted just before
  // the function terminator. Weights are forwarded as-is; inputs that are
  // NOT produced by another SpatWeightedCompute become block arguments,
  // while inputs produced by another compute node are replaced by a
  // SpatChannelReceiveOp fed from that producer's (lazily created) channel.
  // The SpatYieldOp is cloned only when lastCompute is true; either way its
  // remapped operand is recorded as the node's innerValue.
  std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
      SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
    // Insert right before the last op of the entry block (the terminator).
    rewriter.setInsertionPoint(&*std::prev(func.getBody().front().end(), 1));

    ComputeValueResults computeValueResults;

    IRMapping mapper;
    llvm::SmallVector<Value> newComputeOperand;
    llvm::SmallVector<Type> newBBOperandType;
    llvm::SmallVector<Location> newBBLocations;

    // Weights are carried over unchanged.
    for (auto arg : oldWeightedCompute.getWeights())
      newComputeOperand.push_back(arg);

    // Only non-compute-produced inputs stay operands / block arguments.
    // NOTE(review): arg.getDefiningOp() is null for block arguments, which
    // would crash llvm::isa here — presumably inputs are always op
    // results at this point; confirm.
    for (auto arg : oldWeightedCompute.getInputs())
      if (!llvm::isa<SpatWeightedCompute>(arg.getDefiningOp())) {
        newComputeOperand.push_back(arg);
        newBBOperandType.push_back(arg.getType());
        newBBLocations.push_back(loc);
      }

    auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
    rewriter.createBlock(
        &newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
    // Keep the weights/inputs operand split consistent with what we built.
    newWeightedCompute.getProperties().setOperandSegmentSizes(
        {(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
    rewriter.setInsertionPointToEnd(&newWeightedCompute.getBody().front());

    // Map each old input block argument either to the corresponding new
    // block argument, or to a SpatChannelReceiveOp reading the producer's
    // channel. indexOld walks operands starting after the weights.
    int indexNew = 0;
    int indexOld = oldWeightedCompute.getWeights().size();
    int indexOldStart = oldWeightedCompute.getWeights().size();
    for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
      if (!llvm::isa<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart),
            newWeightedCompute.getBody().front().getArgument(indexNew++));
      }
      else {
        auto argWeightCompute =
            llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());

        // This also inserts the producer-side SpatChannelSendOp.
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
        auto [channelVal, _] = lazyArgWeight.getAsChannelValueAndInsertSender();
        spatial::SpatChannelReceiveOp reciveOp =
            spatial::SpatChannelReceiveOp::create(rewriter, loc, channelVal.getType(), channelVal);
        mapper.map(oldWeightedCompute.getBody().front().getArgument(indexOld - indexOldStart), reciveOp);
      }
    }

    // Clone the body; skip the yield unless this is the CPU's last node.
    for (auto& op : oldWeightedCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
      else
        rewriter.clone(op, mapper);
    }

    // Re-point a func.return that consumed the old node at the new one.
    for (auto users : oldWeightedCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(users))
        funcRet.setOperand(0, newWeightedCompute.getResult(0));

    oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
    return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
  }

  // Absorbs fromCompute into the already-built toCompute (same CPU):
  // appends fromCompute's weights and non-compute-produced inputs to
  // toCompute's operand segments (asserting the segment bookkeeping),
  // replaces compute-produced inputs with channel receives (or the local
  // value, when the producer was merged into toCompute itself), and clones
  // the body at the end of toCompute's block.
  std::pair<SpatWeightedCompute, ComputeValueResults>
  mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
    func::FuncOp func = getOperation();
    auto loc = func.getLoc();
    IRRewriter rewriter(&getContext());
    IRMapping mapper;

    auto weightMutableIter = toCompute.getWeightsMutable();
    for (auto weight : fromCompute.getWeights()) {
      int sizeW = toCompute.getWeights().size();
      int sizeI = toCompute.getInputs().size();
      weightMutableIter.append(weight);
      // Sanity-check the operand-segment update done by append().
      assert(sizeW + 1 == toCompute.getWeights().size());
      assert(sizeI == toCompute.getInputs().size());
      assert(sizeW + sizeI + 1 == toCompute.getOperands().size());
    }

    auto inputeArgMutable = toCompute.getInputsMutable();
    // Insert reciveOp
    rewriter.setInsertionPointToEnd(&toCompute.getBody().front());
    int newBBindex = toCompute.getBody().front().getArguments().size();
    for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
      if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
        LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);

        // Passing toCompute lets the lazy result detect "producer merged
        // into this very node" and hand back the local value instead.
        LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
            lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
        if (channelOrLocal.isChannel) {
          spatial::SpatChannelReceiveOp reciveOp =
              spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), reciveOp.getResult());
        }
        else {
          mapper.map(fromCompute.getBody().front().getArgument(bbIndex), channelOrLocal.data);
        }
      }
      else {
        // External input: append operand + matching block argument.
        int sizeW = toCompute.getWeights().size();
        int sizeI = toCompute.getInputs().size();
        inputeArgMutable.append(arg);
        assert(sizeW == toCompute.getWeights().size());
        assert(sizeI + 1 == toCompute.getInputs().size());
        assert(sizeW + sizeI + 1 == toCompute.getOperands().size());

        toCompute.getBody().front().addArgument(
            fromCompute.getBody().front().getArgument(bbIndex).getType(),loc);

        mapper.map(fromCompute.getBody().front().getArgument(bbIndex),
            toCompute.getBody().front().getArgument(newBBindex++));
      }
    }

    // Every old block argument must have been remapped before cloning.
    for (auto oldBBarg : fromCompute.getBody().front().getArguments())
      assert(mapper.contains(oldBBarg));

    ComputeValueResults computeValueResults;
    for (auto& op : fromCompute.getOps()) {
      if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
        computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
        if (lastCompute)
          rewriter.clone(op, mapper);
      }
      else {
        rewriter.clone(op, mapper);
      }
    }

    for (auto users : fromCompute->getUsers())
      if (auto funcRet = dyn_cast<func::ReturnOp>(users))
        funcRet.setOperand(0, toCompute.getResult(0));

    oldToNewComputeMap.insert({fromCompute, toCompute});
    return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
  }

  // Builds the lazy channel-plumbing record for a compute node: the
  // deferred callback creates a SpatChannelNewOp at the saved function-
  // entry insertion point and returns a second callback that inserts the
  // matching SpatChannelSendOp wherever the caller decides.
  // NOTE(review): lastCompute is accepted but never read here, and
  // isOnlyChannel is hard-coded to false — confirm both are intentional.
  LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
      ComputeValueResults computeValueResults,
      bool lastCompute) {
    func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
    auto* context = &getContext();
    auto loc = funcOp.getLoc();
    IRRewriter rewriter(context);

    // Channel-new ops always go at the start of the function entry block.
    rewriter.setInsertionPointToStart(&funcOp.front());
    auto saveInsertionPointChnNew = rewriter.saveInsertionPoint();
    auto insertNew = [saveInsertionPointChnNew, context, loc, computeValueResults]() {
      IRRewriter rewriter(context);
      rewriter.restoreInsertionPoint(saveInsertionPointChnNew);
      auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
      auto channelVal = channelOp.getResult();
      // NOTE(review): `context` is captured BY REFERENCE into a callback
      // that escapes this lambda; it aliases the enclosing closure's
      // by-value copy, which lives inside a std::function stored in a
      // DenseMap — a map rehash moves that storage and would leave this
      // reference dangling. Capturing by value looks intended; confirm.
      auto insertVal =
          [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint insertPointChnSend) {
            IRRewriter rewriter(context);
            rewriter.restoreInsertionPoint(insertPointChnSend);
            auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
            return spatSend;
          };
      std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
      return ret;
    };
    return LazyInsertComputeResult(computeValueResults, insertNew, false);
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Factory used by the pass-registration machinery to instantiate the DCP
/// compute-node merging pass ("pim-merge-node-pass").
std::unique_ptr<Pass> createMergeComputeNodePass() {
  return std::make_unique<MergeComputeNodePass>();
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Reference in New Issue
Block a user