DCP Merge status
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m29s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m29s
This commit is contained in:
@@ -8,14 +8,19 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <fstream>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "DCPGraph/DCPAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
@@ -26,6 +31,41 @@ namespace onnx_mlir {
|
||||
namespace {
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
|
||||
/// Writes a plain-text statistics report for every `SpatCompute` op found
/// directly inside `funcOp`.
///
/// The report is written to `<outputDir>/dialects/stats/<name>.txt`, where
/// `<outputDir>` is the value returned by `getOutputDir()`. When that value is
/// empty, no report is produced. For each compute op, in iteration order, the
/// report lists:
///   - the number of operations in the first block of the op's region, and
///   - the number of weight operands (labelled "used crossbars").
///
/// \param funcOp function whose `SpatCompute` ops are summarized.
/// \param name   base name (without extension) of the report file.
void generateReport(func::FuncOp funcOp, const std::string& name) {
  std::string outputDir = getOutputDir();
  if (outputDir.empty())
    return;

  std::string statsDir = outputDir + "/dialects/stats";
  createDirectory(statsDir);

  const std::string reportPath = statsDir + "/" + name + ".txt";
  std::fstream file(reportPath, std::ios::out);
  if (!file.is_open()) {
    // Without this check every write below would be silently discarded.
    llvm::errs() << "warning: cannot open report file '" << reportPath << "'\n";
    return;
  }

  llvm::raw_os_ostream os(file);

  // Stream each entry as it is visited; there is no need to buffer the
  // statistics in intermediate containers first.
  uint64_t computeIndex = 0;
  for (auto spatCompute : funcOp.getOps<SpatCompute>()) {
    auto& body = spatCompute.getRegion().front();
    const auto numInstructions = static_cast<uint64_t>(std::distance(body.begin(), body.end()));
    const auto numCrossbars = static_cast<uint64_t>(spatCompute.getWeights().size());

    os << "Compute " << computeIndex << ":\n";
    os << "\tNumber of instructions " << numInstructions << "\n";
    os << "\tNumber of used crossbars " << numCrossbars << "\n";
    ++computeIndex;
  }

  // Push the adaptor's buffer into `file` before the file is closed.
  os.flush();
  file.close();
}
|
||||
|
||||
struct ComputeValueResults {
|
||||
SmallVector<Value> innerValues;
|
||||
|
||||
@@ -45,9 +85,7 @@ public:
|
||||
LazyInsertComputeResult(ComputeValueResults computeValueResults,
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
|
||||
bool isOnlyChannel)
|
||||
: computeResults(computeValueResults),
|
||||
onlyChannel(isOnlyChannel),
|
||||
channelNewInserter(channelNewInserter) {}
|
||||
: computeResults(computeValueResults), onlyChannel(isOnlyChannel), channelNewInserter(channelNewInserter) {}
|
||||
|
||||
struct ChannelOrLocalOp {
|
||||
Value data;
|
||||
@@ -107,21 +145,19 @@ public:
|
||||
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
||||
if (!cpuToNewComputeMap.contains(cpu)) {
|
||||
ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||
auto [newCompute, computeValueResult] = createNewComputeNode(
|
||||
currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
auto [newCompute, computeValueResult] =
|
||||
createNewComputeNode(currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cpuToNewComputeMap[cpu] = newCompute;
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
newComputeNodeResults.insert(std::make_pair(
|
||||
currentComputeNode,
|
||||
createLazyComputeResult(newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
}
|
||||
else {
|
||||
auto [newCompute, computeValueResult] = mergeIntoComputeNode(
|
||||
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
newComputeNodeResults.insert(std::make_pair(
|
||||
currentComputeNode,
|
||||
createLazyComputeResult(newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,11 +168,12 @@ public:
|
||||
}
|
||||
func::FuncOp func = getOperation();
|
||||
dumpModule(cast<ModuleOp>(func->getParentOp()), "spatial1_dcp_merged");
|
||||
generateReport(func, "spatial1_dcp_merged_report");
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
|
||||
SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
|
||||
std::pair<SpatCompute, ComputeValueResults>
|
||||
createNewComputeNode(SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
|
||||
func::FuncOp func = getOperation();
|
||||
auto loc = func.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
@@ -161,8 +198,7 @@ private:
|
||||
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
|
||||
|
||||
rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||
|
||||
@@ -178,15 +214,14 @@ private:
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
|
||||
}
|
||||
else {
|
||||
auto argWeightCompute =
|
||||
llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
|
||||
auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
|
||||
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
|
||||
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
|
||||
assert(isChannel == true);
|
||||
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
|
||||
rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
|
||||
spatial::SpatChannelReceiveOp receiveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
|
||||
}
|
||||
}
|
||||
@@ -318,9 +353,8 @@ private:
|
||||
return {cast<SpatCompute>(toCompute), computeValueResults};
|
||||
}
|
||||
|
||||
LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
|
||||
ComputeValueResults computeValueResults,
|
||||
bool lastCompute) {
|
||||
LazyInsertComputeResult
|
||||
createLazyComputeResult(SpatCompute compute, ComputeValueResults computeValueResults, bool lastCompute) {
|
||||
func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
|
||||
auto* context = &getContext();
|
||||
auto loc = funcOp.getLoc();
|
||||
@@ -335,11 +369,12 @@ private:
|
||||
auto channelVal = channelOp.getResult();
|
||||
auto insertVal =
|
||||
[&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(sendInsertPoint);
|
||||
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
|
||||
return spatSend;
|
||||
};
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(sendInsertPoint);
|
||||
auto spatSend =
|
||||
spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
|
||||
return spatSend;
|
||||
};
|
||||
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
|
||||
return ret;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user