multiple-output spat computes

NiccoloN
2026-04-23 09:28:57 +02:00
parent 0f13269040
commit 412ca957f6
16 changed files with 415 additions and 420 deletions

View File

@@ -32,7 +32,7 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
// Execution
//===----------------------------------------------------------------------===//
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
let summary = "Compute region with attached constant weights";
let arguments = (ins

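The renamed op keeps the SingleBlock and AttrSizedOperandSegments traits, so weights and inputs remain separate variadic operand segments while the result list may now carry several values. Below is a minimal builder sketch of assembling a multi-result SpatCompute, inferred from the create/setOperandSegmentSizes calls later in this diff; buildMultiResultCompute itself and the TypeRange/ValueRange plumbing are illustrative, not part of the commit.

// Sketch only: mirrors SpatCompute::create(...) and setOperandSegmentSizes(...)
// as used in MergeComputeNodesPass below; this helper is hypothetical.
static spatial::SpatCompute buildMultiResultCompute(mlir::IRRewriter& rewriter,
                                                    mlir::Location loc,
                                                    mlir::TypeRange resultTypes,
                                                    mlir::ValueRange weights,
                                                    mlir::ValueRange inputs) {
  llvm::SmallVector<mlir::Value> operands(weights.begin(), weights.end());
  operands.append(inputs.begin(), inputs.end());
  auto compute = spatial::SpatCompute::create(rewriter, loc, resultTypes, operands);
  // AttrSizedOperandSegments requires the weight/input split to be recorded.
  compute.getProperties().setOperandSegmentSizes(
      {(int) weights.size(), (int) inputs.size()});
  // The body's terminating yield must produce one value per entry in
  // resultTypes; the updated SpatCompute::verify below enforces this.
  return compute;
}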
View File

@@ -1,5 +1,4 @@
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
@@ -14,10 +13,7 @@
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/LogicalResult.h"
@@ -119,13 +115,10 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
}
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) {
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weightedOp->getParentOp());
if (wcomputeOp)
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
if (auto computeOp = dyn_cast<SpatCompute>(weightedOp->getParentOp()))
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp());
if (coreOp)
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weightedOp->getParentOp()))
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
return failure();
@@ -134,7 +127,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weight
LogicalResult SpatWeightedMVMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -155,7 +148,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
LogicalResult SpatWeightedVMMOp::verify() {
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
if (failed(matrixShapeOpt))
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
auto matrixShape = *matrixShapeOpt;
auto vectorShape = getInput().getType().getShape();
auto outputShape = getOutput().getType().getShape();
@@ -200,9 +193,8 @@ LogicalResult SpatVMaxOp::verify() {
return OpTrait::impl::verifySameOperandsAndResultType(*this);
}
LogicalResult SpatWeightedCompute::verify() {
// Check that it has a terminator, it is a yieldOp, and it has a single
// operand with the same type as the result
LogicalResult SpatCompute::verify() {
// Check that the terminator yields the same number and types as the compute results.
auto& block = getBody().front();
if (block.mightHaveTerminator()) {
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
@@ -257,7 +249,7 @@ LogicalResult SpatWeightedCompute::verify() {
return success();
}
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
Block& block = getBody().front();
if (!llvm::hasSingleElement(block))
return failure();

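The rewritten comment in SpatCompute::verify states the multi-result contract; the commit view truncates the body, but a check of that shape typically reads as follows. This is a sketch under that assumption, reusing the SpatYieldOp cast visible above, not the commit's exact code.

// Hypothetical body for the check described in SpatCompute::verify.
auto& block = getBody().front();
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
if (!yieldOp)
  return emitOpError("expected the body to end with a spat yield");
if (yieldOp.getNumOperands() != getNumResults())
  return emitOpError("yield must return one value per compute result");
for (auto [yieldType, resultType] :
     llvm::zip(yieldOp.getOperandTypes(), getResultTypes()))
  if (yieldType != resultType)
    return emitOpError("yield operand type does not match compute result type");
return success();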
View File

@@ -28,7 +28,7 @@ using namespace mlir;
namespace {
struct VirtualNode {
llvm::SmallVector<size_t, 4> originalComputeIndices;
SmallVector<size_t, 4> originalComputeIndices;
Weight weight = 0;
CrossbarUsage crossbarUsage = 0;
};
@@ -50,7 +50,7 @@ struct WindowScheduleResult {
bool usedAllAvailableCpus = false;
};
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
for (auto [start, end, weight] : edges) {
size_t startIndex = static_cast<size_t>(start);
@@ -74,15 +74,14 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
return aggregatedEdges;
}
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<IndexedEdge> edges) {
VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.reserve(spatWeightedComputes.size());
for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
graph.nodes.reserve(spatComputes.size());
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
VirtualNode node;
node.originalComputeIndices.push_back(index);
node.weight = getSpatComputeWeight(spatWeightedCompute);
node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute);
node.weight = getSpatComputeWeight(spatCompute);
node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
graph.nodes.push_back(std::move(node));
}
graph.edges = aggregateEdges(edges);
@@ -174,7 +173,7 @@ std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t window
return selected;
}
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes) {
std::vector<size_t> signature;
for (size_t nodeIndex : selectedNodes) {
const VirtualNode& node = graph.nodes[nodeIndex];
@@ -197,8 +196,7 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
return aggregateEdges(windowEdges);
}
WindowScheduleResult
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
std::vector<Weight> windowWeights;
std::vector<CrossbarUsage> windowCrossbarUsage;
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
@@ -234,9 +232,7 @@ scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes,
return result;
}
bool coarsenGraph(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> mergeGroups, VirtualGraph& coarsenedGraph) {
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
if (mergeGroup.size() < 2)
@@ -303,7 +299,7 @@ bool coarsenGraph(const VirtualGraph& graph,
}
bool coarsenGraphWithFallback(const VirtualGraph& graph,
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
ArrayRef<std::vector<size_t>> mergeGroups,
VirtualGraph& coarsenedGraph) {
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
return true;
@@ -330,7 +326,7 @@ bool coarsenGraphWithFallback(const VirtualGraph& graph,
return !acceptedMergeGroups.empty();
}
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRef<IndexedEdge> edges) {
VirtualGraph graph;
graph.nodes.resize(computeCount);
graph.edges = aggregateEdges(edges);
@@ -344,22 +340,22 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::A
}
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<IndexedEdge> originalEdges) {
ArrayRef<SpatCompute> spatComputes,
ArrayRef<IndexedEdge> originalEdges) {
DCPAnalysisResult result;
std::vector<size_t> originalToVirtualNode(spatWeightedComputes.size(), 0);
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
for (size_t originalIndex : virtualNode.originalComputeIndices)
originalToVirtualNode[originalIndex] = virtualNodeIndex;
auto dominanceOrder = computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges);
auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
result.dominanceOrderCompute.reserve(dominanceOrder.size());
for (size_t originalIndex : dominanceOrder) {
SpatWeightedCompute spatWeightedCompute = spatWeightedComputes[originalIndex];
SpatCompute spatCompute = spatComputes[originalIndex];
size_t cpu = originalToVirtualNode[originalIndex];
result.dominanceOrderCompute.push_back(spatWeightedCompute);
result.computeToCpuMap[spatWeightedCompute] = cpu;
result.cpuToLastComputeMap[cpu] = spatWeightedCompute;
result.dominanceOrderCompute.push_back(spatCompute);
result.computeToCpuMap[spatCompute] = cpu;
result.cpuToLastComputeMap[cpu] = spatCompute;
}
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
@@ -367,10 +363,8 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
return result;
}
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
llvm::ArrayRef<IndexedEdge> edges,
MLIRContext* context) {
GraphDCP graphDCP(spatWeightedComputes, edges);
DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
GraphDCP graphDCP(spatComputes, edges);
if (coresCount.getValue() > 0)
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
graphDCP.setContext(context);
@@ -380,47 +374,41 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedC
} // namespace
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
SpatCompute getOriginalSpatCompute(Operation* op) {
if (!op)
return {};
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
op = extract.getSource().getDefiningOp();
if (!op)
return {};
}
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
if (auto res = dyn_cast<SpatCompute>(op))
return res;
return {};
}
DCPAnalysisResult DCPAnalysis::run() {
SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
SmallVector<SpatCompute, 10> spatComputes;
SmallVector<IndexedEdge, 10> edges;
for (auto& region : entryOp->getRegions())
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
spatWeightedComputes.push_back(spatWeightedCompute);
for (SpatCompute spatCompute : region.getOps<SpatCompute>())
spatComputes.push_back(spatCompute);
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
for (Value input : spatWeightedCompute.getInputs()) {
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) {
auto producerIt = llvm::find(spatWeightedComputes, producerCompute);
assert(producerIt != spatWeightedComputes.end());
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt);
ResultRange outputs = producerCompute.getResults();
int64_t totalSize = 0;
for (auto output : outputs) {
ShapedType resultType = cast<ShapedType>(output.getType());
totalSize += getSizeInBytes(resultType);
}
edges.push_back({indexStartEdge, indexEndEdge, totalSize});
for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
for (Value input : spatCompute.getInputs()) {
if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
auto producerIt = llvm::find(spatComputes, producerCompute);
assert(producerIt != spatComputes.end());
auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
}
}
}
if (dcpCriticalWindowSize.getValue() == 0)
return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext());
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges);
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
std::set<std::vector<size_t>> seenCriticalWindows;
while (virtualGraph.nodes.size() > 1) {
TimingInfo timing = computeTiming(virtualGraph);
@@ -446,7 +434,7 @@ DCPAnalysisResult DCPAnalysis::run() {
break;
}
return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges);
return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
}
} // namespace spatial

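Note the edge-weight change above: an edge is now recorded per consumed input value, weighted by that value's byte size, instead of charging the producer's entire result set on every edge (the deleted totalSize loop). aggregateEdges then merges duplicate producer-consumer pairs. Below is a standalone model of that merge step, with the IndexedEdge layout assumed from the structured bindings in the diff.

#include <cstdint>
#include <map>
#include <utility>
#include <vector>

using Weight = int64_t;
struct IndexedEdge { int64_t start; int64_t end; Weight weight; };  // assumed layout

// Model of aggregateEdges: duplicate producer->consumer edges (one per
// consumed result under the new scheme) are merged by summing their weights.
std::vector<IndexedEdge> aggregate(const std::vector<IndexedEdge>& edges) {
  std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
  for (auto [start, end, weight] : edges)
    edgeWeights[{static_cast<size_t>(start), static_cast<size_t>(end)}] += weight;
  std::vector<IndexedEdge> aggregated;
  aggregated.reserve(edgeWeights.size());
  for (const auto& [endpoints, weight] : edgeWeights)
    aggregated.push_back({static_cast<int64_t>(endpoints.first),
                          static_cast<int64_t>(endpoints.second), weight});
  return aggregated;
}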
View File

@@ -10,10 +10,10 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
struct DCPAnalysisResult {
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap;
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu;
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
};
namespace onnx_mlir {

View File

@@ -1260,7 +1260,7 @@ DCPAnalysisResult GraphDCP::getResult() {
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
ret.dominanceOrderCompute.reserve(dominanceOrder.size());
for (auto elem : dominanceOrder)
ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute());
ret.dominanceOrderCompute.push_back(elem->getSpatCompute());
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
const CpuTaskList* tasks = findCpuTasks(cpu);
@@ -1268,10 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() {
continue;
size_t i = 0;
for (auto node : *tasks) {
ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu;
ret.computeToCpuMap[node->getSpatCompute()] = cpu;
if (i++ == tasks->size() - 1) {
ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute());
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute();
ret.isLastComputeOfCpu.insert(node->getSpatCompute());
ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
}
}
}

View File

@@ -115,11 +115,11 @@ private:
public:
void runDcp();
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
llvm::ArrayRef<IndexedEdge> edges)
: nodes(), cpuTasks(), cpuCrossbarUsage() {
for (auto spatWeightedCompute : spatWeightedComputes)
nodes.emplace_back(spatWeightedCompute);
for (auto spatCompute : spatComputes)
nodes.emplace_back(spatCompute);
for (auto [start, end, weight] : edges)
makeEdge(start, end, weight);
}

View File

@@ -8,7 +8,7 @@
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
onnx_mlir::spatial::SpatCompute spatCompute;
Time aest;
Time alst;
std::optional<CPU> scheduledCpu;
@@ -38,22 +38,22 @@ public:
std::vector<Edge> parents;
std::vector<Edge> children;
TaskDCP() = default;
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(spatWeightedCompute),
spatCompute(spatCompute),
aest(0),
alst(0),
scheduledCpu(),
weight(getSpatComputeWeight(spatWeightedCompute)),
weight(getSpatComputeWeight(spatCompute)),
baseWeight(weight),
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)),
crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
syntheticId(-1),
parents(),
children() {}
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
: onnx_mlir::LabeledListNode<TaskDCP>(),
spatWeightedCompute(),
spatCompute(),
aest(0),
alst(0),
scheduledCpu(),
@@ -90,14 +90,14 @@ public:
void setAlst(Time value) { alst = value; }
bool hasDescendant(TaskDCP* child);
int64_t Id() const {
if (spatWeightedCompute)
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
if (spatCompute)
return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
return syntheticId;
}
bool isCriticalPath() const { return alst == aest; }
bool isScheduled() const { return scheduledCpu.has_value(); }
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }
void setFlag(long long val) { flag = val; }
long long getFlag() const { return flag; }

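TaskDCP keys real nodes on the compute op itself and placeholder nodes on a synthetic id; getAsOpaquePointer is the usual MLIR way to obtain a unique integer per op. A reduced, dialect-free model of Id():

#include <cassert>
#include <cstdint>

// Reduced model of TaskDCP::Id(): real nodes use the op's opaque pointer,
// synthetic nodes (no attached compute) fall back to syntheticId.
struct Node {
  const void* op = nullptr;  // stand-in for SpatCompute::getAsOpaquePointer()
  int64_t syntheticId = -1;
  int64_t id() const { return op ? reinterpret_cast<int64_t>(op) : syntheticId; }
};

int main() {
  int dummy = 0;
  Node real{&dummy, -1};
  Node synthetic{nullptr, -7};
  assert(real.id() == reinterpret_cast<int64_t>(&dummy));
  assert(synthetic.id() == -7);
  return 0;
}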
View File

@@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) {
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
constexpr Weight kOperationWeight = 100;
Weight numOperations = 0;
for (auto& block : spatWeightedCompute.getBody())
for (auto& block : spatCompute.getBody())
for ([[maybe_unused]] auto& op : block)
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
return checkedMultiply(numOperations, kOperationWeight);
}
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
CrossbarUsage crossbarUsage = 0;
for (auto& region : spatWeightedCompute.getBody())
for (auto& region : spatCompute.getBody())
for (auto& inst : region)
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));

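Concretely, both helpers reduce to counting ops in the body: every op contributes a fixed weight of 100, and every SpatWeightedVMMOp contributes one crossbar. A runnable toy model follows; the real helpers additionally route every step through checkedAdd/checkedMultiply to guard against overflow.

#include <cassert>
#include <cstdint>

using Weight = int64_t;

// Toy model of getSpatComputeWeight: weight = (#ops in body) * 100.
constexpr Weight kOperationWeight = 100;
Weight weightFor(Weight numOpsInBody) { return numOpsInBody * kOperationWeight; }

int main() {
  assert(weightFor(7) == 700);  // a 7-op body schedules with weight 700
  // Crossbar usage would be 2 for a body containing two weighted-VMM ops.
  return 0;
}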
View File

@@ -24,30 +24,29 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
using SpatWeightedCompute = spatial::SpatWeightedCompute;
using SpatCompute = spatial::SpatCompute;
struct ComputeValueResults {
// Value yielded by the yieldOp
Value innerValue;
SmallVector<Value> innerValues;
Value get(size_t resultIndex) const {
assert(resultIndex < innerValues.size() && "compute result index out of range");
return innerValues[resultIndex];
}
};
class LazyInsertComputeResult {
using InsertPoint = mlir::IRRewriter::InsertPoint;
ComputeValueResults computeResults;
Value channelValue;
bool onlyChannel;
std::function<void(InsertPoint insertPoint)> channelSendInserter;
InsertPoint sendInsertPoint;
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter;
public:
LazyInsertComputeResult(ComputeValueResults computeValueResults,
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
bool isOnlyChannel)
: computeResults(computeValueResults),
onlyChannel(isOnlyChannel),
channelSendInserter(nullptr),
sendInsertPoint({}),
channelNewInserter(channelNewInserter) {}
struct ChannelOrLocalOp {
@@ -57,12 +56,12 @@ public:
bool onlyChanneled() const { return onlyChannel; }
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) {
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) {
Value innerValue = computeResults.get(resultIndex);
auto [newChannelValue, senderInserter] = channelNewInserter();
channelValue = newChannelValue;
channelSendInserter = senderInserter;
auto* block = computeResults.innerValue.getParentBlock();
auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex);
InsertPoint sendInsertPoint;
auto* block = innerValue.getParentBlock();
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
sendInsertPoint = InsertPoint(block, --block->end());
else
@@ -70,28 +69,30 @@ public:
if (currentCompute) {
for (auto& block : currentCompute.getBody())
if (&block == sendInsertPoint.getBlock())
return {computeResults.innerValue, false};
return {innerValue, false};
}
channelSendInserter(sendInsertPoint);
return {channelValue, true};
}
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) {
return getAsChannelValueAndInsertSender({}, resultIndex);
}
};
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
private:
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;
DenseMap<SpatCompute, LazyInsertComputeResult> newComputeNodeResults;
DenseMap<SpatCompute, SpatCompute> oldToNewComputeMap;
DenseMap<int64_t, SpatCompute> cpuToNewComputeMap;
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
StringRef getDescription() const override {
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
return "Merge Spatial-Compute-Nodes in order to reduce the total "
"execution time";
}
@@ -105,22 +106,22 @@ public:
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
if (!cpuToNewComputeMap.contains(cpu)) {
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
cpuToNewComputeMap[cpu] = newWeightedCompute;
ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
auto [newCompute, computeValueResult] = createNewComputeNode(
currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
cpuToNewComputeMap[cpu] = newCompute;
newComputeNodeResults.insert(
std::make_pair(currentComputeNode,
createLazyComputeResult(
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
}
else {
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
auto [newCompute, computeValueResult] = mergeIntoComputeNode(
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
newComputeNodeResults.insert(
std::make_pair(currentComputeNode,
createLazyComputeResult(
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
}
}
@@ -134,8 +135,8 @@ public:
}
private:
std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
func::FuncOp func = getOperation();
auto loc = func.getLoc();
IRRewriter rewriter(&getContext());
@@ -148,50 +149,53 @@ private:
llvm::SmallVector<Type> newBBOperandType;
llvm::SmallVector<Location> newBBLocations;
for (auto arg : oldWeightedCompute.getWeights())
for (auto arg : oldCompute.getWeights())
newComputeOperand.push_back(arg);
for (auto arg : oldWeightedCompute.getInputs())
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
for (auto arg : oldCompute.getInputs())
if (!llvm::isa_and_present<SpatCompute>(arg.getDefiningOp())) {
newComputeOperand.push_back(arg);
newBBOperandType.push_back(arg.getType());
newBBLocations.push_back(loc);
}
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
rewriter.createBlock(
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
newWeightedCompute.getProperties().setOperandSegmentSizes(
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
newCompute.getProperties().setOperandSegmentSizes(
{(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
auto& newBB = newWeightedCompute.getBody().front();
auto& oldBB = oldWeightedCompute.getBody().front();
auto& newBB = newCompute.getBody().front();
auto& oldBB = oldCompute.getBody().front();
rewriter.setInsertionPointToEnd(&newBB);
int indexNew = 0;
size_t indexOld = oldWeightedCompute.getWeights().size();
size_t indexOldStart = oldWeightedCompute.getWeights().size();
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
size_t indexOld = oldCompute.getWeights().size();
size_t indexOldStart = oldCompute.getWeights().size();
for (; indexOld < oldCompute.getNumOperands(); ++indexOld) {
if (!llvm::isa_and_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp())) {
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
}
else {
auto argWeightCompute =
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
assert(isChannel == true);
spatial::SpatChannelReceiveOp receiveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
}
}
for (auto& op : oldWeightedCompute.getOps()) {
for (auto& op : oldCompute.getOps()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
computeValueResults.innerValues.reserve(yield.getNumOperands());
for (Value yieldOperand : yield.getOperands())
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
if (lastCompute)
rewriter.clone(op, mapper);
}
@@ -199,16 +203,18 @@ private:
rewriter.clone(op, mapper);
}
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
if (isa<func::ReturnOp>(use.getOwner()))
use.assign(newWeightedCompute.getResult(0));
for (auto& use : llvm::make_early_inc_range(oldCompute->getUses()))
if (isa<func::ReturnOp>(use.getOwner())) {
auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
use.assign(newCompute.getResult(resultIndex));
}
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
oldToNewComputeMap.insert({oldCompute, newCompute});
return {cast<SpatCompute>(newCompute), computeValueResults};
}
std::pair<SpatWeightedCompute, ComputeValueResults>
mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
std::pair<SpatCompute, ComputeValueResults>
mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) {
func::FuncOp func = getOperation();
auto loc = func.getLoc();
IRRewriter rewriter(&getContext());
@@ -239,14 +245,15 @@ private:
// Insert receiveOp
rewriter.setInsertionPointToEnd(&toBB);
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(arg.getDefiningOp())) {
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
auto argResultIndex = cast<OpResult>(arg).getResultNumber();
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex);
if (channelOrLocal.isChannel) {
spatial::SpatChannelReceiveOp receiveOp =
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data);
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
}
else {
@@ -286,7 +293,9 @@ private:
};
for (auto& op : fromCompute.getOps()) {
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
computeValueResults.innerValues.reserve(yield.getNumOperands());
for (Value yieldOperand : yield.getOperands())
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
if (lastCompute)
rewriter.clone(op, mapper);
}
@@ -299,33 +308,36 @@ private:
}
}
for (auto users : fromCompute->getUsers())
if (auto funcRet = dyn_cast<func::ReturnOp>(users))
funcRet.setOperand(0, toCompute.getResult(0));
for (auto& use : llvm::make_early_inc_range(fromCompute->getUses()))
if (isa<func::ReturnOp>(use.getOwner())) {
auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
use.assign(toCompute.getResult(resultIndex));
}
oldToNewComputeMap.insert({fromCompute, toCompute});
return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
return {cast<SpatCompute>(toCompute), computeValueResults};
}
LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
ComputeValueResults computeValueResults,
bool lastCompute) {
func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
auto* context = &getContext();
auto loc = funcOp.getLoc();
IRRewriter rewriter(context);
rewriter.setInsertionPointToStart(&funcOp.front());
auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() {
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(savedChannelInsertPoint);
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
auto channelVal = channelOp.getResult();
auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) {
auto insertVal =
[&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
IRRewriter rewriter(context);
rewriter.restoreInsertionPoint(sendInsertPoint);
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
return spatSend;
};
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
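The closing lambdas implement a two-stage deferred insertion, now indexed per result: channelNewInserter(resultIndex) creates the channel immediately at the saved insertion point, while the matching send is captured in a closure and only materialized once the point just before the producer's yield is known. Below is a dialect-free sketch of that pattern; all names in it are illustrative.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using InsertPoint = std::size_t;  // stand-in for IRRewriter::InsertPoint

// Stage 1 allocates the channel now; stage 2 (the returned closure) emits the
// send later, once the caller knows where it must go.
std::pair<std::string, std::function<void(InsertPoint)>>
makeChannel(std::vector<std::string>& program, std::size_t resultIndex) {
  std::string channel = "ch" + std::to_string(resultIndex);
  program.push_back("new_channel " + channel);  // inserted immediately
  auto sendInserter = [&program, channel, resultIndex](InsertPoint where) {
    program.insert(program.begin() + static_cast<std::ptrdiff_t>(where),
                   "send " + channel + ", result#" + std::to_string(resultIndex));
  };
  return {channel, sendInserter};  // send is deferred
}

int main() {
  std::vector<std::string> program;
  auto [channel, insertSend] = makeChannel(program, 1);
  program.push_back("yield");
  insertSend(program.size() - 1);  // place the send just before the yield
  for (const auto& line : program)
    std::cout << line << '\n';
  return 0;
}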