multiple-output spat computes
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
All checks were successful
Validate Operations / validate-operations (push) Successful in 22m38s
This commit is contained in:
@@ -32,7 +32,7 @@ def SpatChannelType : SpatType<"SpatChannel", "ch"> {
|
||||
// Execution
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def SpatWeightedCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
def SpatCompute : SpatOp<"compute", [SingleBlock, AttrSizedOperandSegments]> {
|
||||
let summary = "Compute region with attached constant weights";
|
||||
|
||||
let arguments = (ins
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
@@ -14,10 +13,7 @@
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
@@ -119,13 +115,10 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter,
|
||||
}
|
||||
|
||||
llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigthedOp, size_t weightIndex) {
|
||||
auto wcomputeOp = dyn_cast<SpatWeightedCompute>(weigthedOp->getParentOp());
|
||||
if (wcomputeOp)
|
||||
return cast<ShapedType>(wcomputeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
if (auto computeOp = dyn_cast<SpatCompute>(weigthedOp->getParentOp()))
|
||||
return cast<ShapedType>(computeOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp());
|
||||
|
||||
if (coreOp)
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(weigthedOp->getParentOp()))
|
||||
return cast<ShapedType>(coreOp.getWeights()[weightIndex].getType()).getShape();
|
||||
|
||||
return failure();
|
||||
@@ -134,7 +127,7 @@ llvm::FailureOr<ArrayRef<int64_t>> getWeightShapeForWeightedOp(Operation* weigth
|
||||
LogicalResult SpatWeightedMVMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatWeightedCompute or Core op");
|
||||
return emitError("SpatWeightedMVMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -155,7 +148,7 @@ LogicalResult SpatWeightedMVMOp::verify() {
|
||||
LogicalResult SpatWeightedVMMOp::verify() {
|
||||
auto matrixShapeOpt = getWeightShapeForWeightedOp(this->getOperation(), this->getWeightIndex());
|
||||
if (failed(matrixShapeOpt))
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatWeightedCompute or Core op");
|
||||
return emitError("SpatWeightedVMMOp was not within a SpatCompute or Core op");
|
||||
auto matrixShape = *matrixShapeOpt;
|
||||
auto vectorShape = getInput().getType().getShape();
|
||||
auto outputShape = getOutput().getType().getShape();
|
||||
@@ -200,9 +193,8 @@ LogicalResult SpatVMaxOp::verify() {
|
||||
return OpTrait::impl::verifySameOperandsAndResultType(*this);
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::verify() {
|
||||
// Check that it has a terminator, it is a yieldOp, and it has a single
|
||||
// operand with the same type as the result
|
||||
LogicalResult SpatCompute::verify() {
|
||||
// Check that the terminator yields the same number and types as the compute results.
|
||||
auto& block = getBody().front();
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
@@ -257,7 +249,7 @@ LogicalResult SpatWeightedCompute::verify() {
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult SpatWeightedCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
LogicalResult SpatCompute::fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult>& results) {
|
||||
Block& block = getBody().front();
|
||||
if (!llvm::hasSingleElement(block))
|
||||
return failure();
|
||||
|
||||
@@ -28,7 +28,7 @@ using namespace mlir;
|
||||
namespace {
|
||||
|
||||
struct VirtualNode {
|
||||
llvm::SmallVector<size_t, 4> originalComputeIndices;
|
||||
SmallVector<size_t, 4> originalComputeIndices;
|
||||
Weight weight = 0;
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
};
|
||||
@@ -50,7 +50,7 @@ struct WindowScheduleResult {
|
||||
bool usedAllAvailableCpus = false;
|
||||
};
|
||||
|
||||
std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||
std::vector<IndexedEdge> aggregateEdges(ArrayRef<IndexedEdge> edges) {
|
||||
std::map<std::pair<size_t, size_t>, Weight> edgeWeights;
|
||||
for (auto [start, end, weight] : edges) {
|
||||
size_t startIndex = static_cast<size_t>(start);
|
||||
@@ -74,15 +74,14 @@ std::vector<IndexedEdge> aggregateEdges(llvm::ArrayRef<IndexedEdge> edges) {
|
||||
return aggregatedEdges;
|
||||
}
|
||||
|
||||
VirtualGraph buildInitialVirtualGraph(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
||||
llvm::ArrayRef<IndexedEdge> edges) {
|
||||
VirtualGraph buildInitialVirtualGraph(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges) {
|
||||
VirtualGraph graph;
|
||||
graph.nodes.reserve(spatWeightedComputes.size());
|
||||
for (auto [index, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
||||
graph.nodes.reserve(spatComputes.size());
|
||||
for (auto [index, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||
VirtualNode node;
|
||||
node.originalComputeIndices.push_back(index);
|
||||
node.weight = getSpatComputeWeight(spatWeightedCompute);
|
||||
node.crossbarUsage = getSpatComputeCrossbarUsage(spatWeightedCompute);
|
||||
node.weight = getSpatComputeWeight(spatCompute);
|
||||
node.crossbarUsage = getSpatComputeCrossbarUsage(spatCompute);
|
||||
graph.nodes.push_back(std::move(node));
|
||||
}
|
||||
graph.edges = aggregateEdges(edges);
|
||||
@@ -174,7 +173,7 @@ std::vector<size_t> selectCriticalWindow(const TimingInfo& timing, size_t window
|
||||
return selected;
|
||||
}
|
||||
|
||||
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes) {
|
||||
std::vector<size_t> getOriginalSignature(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes) {
|
||||
std::vector<size_t> signature;
|
||||
for (size_t nodeIndex : selectedNodes) {
|
||||
const VirtualNode& node = graph.nodes[nodeIndex];
|
||||
@@ -197,8 +196,7 @@ std::vector<IndexedEdge> buildWindowEdges(const VirtualGraph& graph, const std::
|
||||
return aggregateEdges(windowEdges);
|
||||
}
|
||||
|
||||
WindowScheduleResult
|
||||
scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
||||
WindowScheduleResult scheduleWindow(const VirtualGraph& graph, ArrayRef<size_t> selectedNodes, MLIRContext* context) {
|
||||
std::vector<Weight> windowWeights;
|
||||
std::vector<CrossbarUsage> windowCrossbarUsage;
|
||||
std::vector<int64_t> nodeToWindowIndex(graph.nodes.size(), -1);
|
||||
@@ -234,9 +232,7 @@ scheduleWindow(const VirtualGraph& graph, llvm::ArrayRef<size_t> selectedNodes,
|
||||
return result;
|
||||
}
|
||||
|
||||
bool coarsenGraph(const VirtualGraph& graph,
|
||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph& coarsenedGraph) {
|
||||
bool coarsenGraph(const VirtualGraph& graph, ArrayRef<std::vector<size_t>> mergeGroups, VirtualGraph& coarsenedGraph) {
|
||||
std::vector<int64_t> nodeToMergeGroup(graph.nodes.size(), -1);
|
||||
for (auto [groupIndex, mergeGroup] : llvm::enumerate(mergeGroups)) {
|
||||
if (mergeGroup.size() < 2)
|
||||
@@ -303,7 +299,7 @@ bool coarsenGraph(const VirtualGraph& graph,
|
||||
}
|
||||
|
||||
bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
||||
llvm::ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
ArrayRef<std::vector<size_t>> mergeGroups,
|
||||
VirtualGraph& coarsenedGraph) {
|
||||
if (coarsenGraph(graph, mergeGroups, coarsenedGraph))
|
||||
return true;
|
||||
@@ -330,7 +326,7 @@ bool coarsenGraphWithFallback(const VirtualGraph& graph,
|
||||
return !acceptedMergeGroups.empty();
|
||||
}
|
||||
|
||||
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::ArrayRef<IndexedEdge> edges) {
|
||||
std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, ArrayRef<IndexedEdge> edges) {
|
||||
VirtualGraph graph;
|
||||
graph.nodes.resize(computeCount);
|
||||
graph.edges = aggregateEdges(edges);
|
||||
@@ -344,22 +340,22 @@ std::vector<size_t> computeOriginalTopologicalOrder(size_t computeCount, llvm::A
|
||||
}
|
||||
|
||||
DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
||||
llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
||||
llvm::ArrayRef<IndexedEdge> originalEdges) {
|
||||
ArrayRef<SpatCompute> spatComputes,
|
||||
ArrayRef<IndexedEdge> originalEdges) {
|
||||
DCPAnalysisResult result;
|
||||
std::vector<size_t> originalToVirtualNode(spatWeightedComputes.size(), 0);
|
||||
std::vector<size_t> originalToVirtualNode(spatComputes.size(), 0);
|
||||
for (auto [virtualNodeIndex, virtualNode] : llvm::enumerate(graph.nodes))
|
||||
for (size_t originalIndex : virtualNode.originalComputeIndices)
|
||||
originalToVirtualNode[originalIndex] = virtualNodeIndex;
|
||||
|
||||
auto dominanceOrder = computeOriginalTopologicalOrder(spatWeightedComputes.size(), originalEdges);
|
||||
auto dominanceOrder = computeOriginalTopologicalOrder(spatComputes.size(), originalEdges);
|
||||
result.dominanceOrderCompute.reserve(dominanceOrder.size());
|
||||
for (size_t originalIndex : dominanceOrder) {
|
||||
SpatWeightedCompute spatWeightedCompute = spatWeightedComputes[originalIndex];
|
||||
SpatCompute spatCompute = spatComputes[originalIndex];
|
||||
size_t cpu = originalToVirtualNode[originalIndex];
|
||||
result.dominanceOrderCompute.push_back(spatWeightedCompute);
|
||||
result.computeToCpuMap[spatWeightedCompute] = cpu;
|
||||
result.cpuToLastComputeMap[cpu] = spatWeightedCompute;
|
||||
result.dominanceOrderCompute.push_back(spatCompute);
|
||||
result.computeToCpuMap[spatCompute] = cpu;
|
||||
result.cpuToLastComputeMap[cpu] = spatCompute;
|
||||
}
|
||||
|
||||
for (auto [cpu, lastCompute] : result.cpuToLastComputeMap)
|
||||
@@ -367,10 +363,8 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph,
|
||||
return result;
|
||||
}
|
||||
|
||||
DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedComputes,
|
||||
llvm::ArrayRef<IndexedEdge> edges,
|
||||
MLIRContext* context) {
|
||||
GraphDCP graphDCP(spatWeightedComputes, edges);
|
||||
DCPAnalysisResult runLegacyDcp(ArrayRef<SpatCompute> spatComputes, ArrayRef<IndexedEdge> edges, MLIRContext* context) {
|
||||
GraphDCP graphDCP(spatComputes, edges);
|
||||
if (coresCount.getValue() > 0)
|
||||
graphDCP.setMaxCpuCount(static_cast<int>(coresCount.getValue()));
|
||||
graphDCP.setContext(context);
|
||||
@@ -380,47 +374,41 @@ DCPAnalysisResult runLegacyDcp(llvm::ArrayRef<SpatWeightedCompute> spatWeightedC
|
||||
|
||||
} // namespace
|
||||
|
||||
SpatWeightedCompute getOriginalSpatWeightedCompute(Operation* op) {
|
||||
SpatCompute getOriginalSpatCompute(Operation* op) {
|
||||
if (!op)
|
||||
return {};
|
||||
while (auto extract = llvm::dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
while (auto extract = dyn_cast<tensor::ExtractSliceOp>(op)) {
|
||||
op = extract.getSource().getDefiningOp();
|
||||
if (!op)
|
||||
return {};
|
||||
}
|
||||
if (auto res = llvm::dyn_cast<SpatWeightedCompute>(op))
|
||||
if (auto res = dyn_cast<SpatCompute>(op))
|
||||
return res;
|
||||
return {};
|
||||
}
|
||||
|
||||
DCPAnalysisResult DCPAnalysis::run() {
|
||||
SmallVector<SpatWeightedCompute, 10> spatWeightedComputes;
|
||||
SmallVector<SpatCompute, 10> spatComputes;
|
||||
SmallVector<IndexedEdge, 10> edges;
|
||||
for (auto& region : entryOp->getRegions())
|
||||
for (SpatWeightedCompute spatWeightedCompute : region.getOps<SpatWeightedCompute>())
|
||||
spatWeightedComputes.push_back(spatWeightedCompute);
|
||||
for (SpatCompute spatCompute : region.getOps<SpatCompute>())
|
||||
spatComputes.push_back(spatCompute);
|
||||
|
||||
for (auto [indexEndEdge, spatWeightedCompute] : llvm::enumerate(spatWeightedComputes)) {
|
||||
for (Value input : spatWeightedCompute.getInputs()) {
|
||||
if (auto producerCompute = getOriginalSpatWeightedCompute(input.getDefiningOp())) {
|
||||
auto producerIt = llvm::find(spatWeightedComputes, producerCompute);
|
||||
assert(producerIt != spatWeightedComputes.end());
|
||||
auto indexStartEdge = std::distance(spatWeightedComputes.begin(), producerIt);
|
||||
ResultRange outputs = producerCompute.getResults();
|
||||
int64_t totalSize = 0;
|
||||
for (auto output : outputs) {
|
||||
ShapedType resultType = cast<ShapedType>(output.getType());
|
||||
totalSize += getSizeInBytes(resultType);
|
||||
}
|
||||
edges.push_back({indexStartEdge, indexEndEdge, totalSize});
|
||||
for (auto [indexEndEdge, spatCompute] : llvm::enumerate(spatComputes)) {
|
||||
for (Value input : spatCompute.getInputs()) {
|
||||
if (auto producerCompute = getOriginalSpatCompute(input.getDefiningOp())) {
|
||||
auto producerIt = llvm::find(spatComputes, producerCompute);
|
||||
assert(producerIt != spatComputes.end());
|
||||
auto indexStartEdge = std::distance(spatComputes.begin(), producerIt);
|
||||
edges.push_back({indexStartEdge, indexEndEdge, getSizeInBytes(cast<ShapedType>(input.getType()))});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (dcpCriticalWindowSize.getValue() == 0)
|
||||
return runLegacyDcp(spatWeightedComputes, edges, entryOp->getContext());
|
||||
return runLegacyDcp(spatComputes, edges, entryOp->getContext());
|
||||
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatWeightedComputes, edges);
|
||||
VirtualGraph virtualGraph = buildInitialVirtualGraph(spatComputes, edges);
|
||||
std::set<std::vector<size_t>> seenCriticalWindows;
|
||||
while (virtualGraph.nodes.size() > 1) {
|
||||
TimingInfo timing = computeTiming(virtualGraph);
|
||||
@@ -446,7 +434,7 @@ DCPAnalysisResult DCPAnalysis::run() {
|
||||
break;
|
||||
}
|
||||
|
||||
return buildResultFromVirtualGraph(virtualGraph, spatWeightedComputes, edges);
|
||||
return buildResultFromVirtualGraph(virtualGraph, spatComputes, edges);
|
||||
}
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
@@ -10,10 +10,10 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
struct DCPAnalysisResult {
|
||||
std::vector<onnx_mlir::spatial::SpatWeightedCompute> dominanceOrderCompute;
|
||||
llvm::DenseMap<onnx_mlir::spatial::SpatWeightedCompute, size_t> computeToCpuMap;
|
||||
llvm::DenseSet<onnx_mlir::spatial::SpatWeightedCompute> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatWeightedCompute> cpuToLastComputeMap;
|
||||
std::vector<onnx_mlir::spatial::SpatCompute> dominanceOrderCompute;
|
||||
llvm::DenseMap<onnx_mlir::spatial::SpatCompute, size_t> computeToCpuMap;
|
||||
llvm::DenseSet<onnx_mlir::spatial::SpatCompute> isLastComputeOfCpu;
|
||||
llvm::DenseMap<size_t, onnx_mlir::spatial::SpatCompute> cpuToLastComputeMap;
|
||||
};
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
@@ -1260,7 +1260,7 @@ DCPAnalysisResult GraphDCP::getResult() {
|
||||
auto dominanceOrder = dcp_graph::collectDominanceOrder(getRoots(), nodes.size());
|
||||
ret.dominanceOrderCompute.reserve(dominanceOrder.size());
|
||||
for (auto elem : dominanceOrder)
|
||||
ret.dominanceOrderCompute.push_back(elem->getSpatWeightedCompute());
|
||||
ret.dominanceOrderCompute.push_back(elem->getSpatCompute());
|
||||
|
||||
for (CPU cpu = 0; cpu < getLastCpu(); ++cpu) {
|
||||
const CpuTaskList* tasks = findCpuTasks(cpu);
|
||||
@@ -1268,10 +1268,10 @@ DCPAnalysisResult GraphDCP::getResult() {
|
||||
continue;
|
||||
size_t i = 0;
|
||||
for (auto node : *tasks) {
|
||||
ret.computeToCpuMap[node->getSpatWeightedCompute()] = cpu;
|
||||
ret.computeToCpuMap[node->getSpatCompute()] = cpu;
|
||||
if (i++ == tasks->size() - 1) {
|
||||
ret.isLastComputeOfCpu.insert(node->getSpatWeightedCompute());
|
||||
ret.cpuToLastComputeMap[cpu] = node->getSpatWeightedCompute();
|
||||
ret.isLastComputeOfCpu.insert(node->getSpatCompute());
|
||||
ret.cpuToLastComputeMap[cpu] = node->getSpatCompute();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,11 +115,11 @@ private:
|
||||
|
||||
public:
|
||||
void runDcp();
|
||||
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatWeightedCompute> spatWeightedComputes,
|
||||
GraphDCP(llvm::ArrayRef<onnx_mlir::spatial::SpatCompute> spatComputes,
|
||||
llvm::ArrayRef<IndexedEdge> edges)
|
||||
: nodes(), cpuTasks(), cpuCrossbarUsage() {
|
||||
for (auto spatWeightedCompute : spatWeightedComputes)
|
||||
nodes.emplace_back(spatWeightedCompute);
|
||||
for (auto spatCompute : spatComputes)
|
||||
nodes.emplace_back(spatCompute);
|
||||
for (auto [start, end, weight] : edges)
|
||||
makeEdge(start, end, weight);
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
class TaskDCP : public onnx_mlir::LabeledListNode<TaskDCP> {
|
||||
onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute;
|
||||
onnx_mlir::spatial::SpatCompute spatCompute;
|
||||
Time aest;
|
||||
Time alst;
|
||||
std::optional<CPU> scheduledCpu;
|
||||
@@ -38,22 +38,22 @@ public:
|
||||
std::vector<Edge> parents;
|
||||
std::vector<Edge> children;
|
||||
TaskDCP() = default;
|
||||
TaskDCP(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute)
|
||||
TaskDCP(onnx_mlir::spatial::SpatCompute spatCompute)
|
||||
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||
spatWeightedCompute(spatWeightedCompute),
|
||||
spatCompute(spatCompute),
|
||||
aest(0),
|
||||
alst(0),
|
||||
scheduledCpu(),
|
||||
weight(getSpatComputeWeight(spatWeightedCompute)),
|
||||
weight(getSpatComputeWeight(spatCompute)),
|
||||
baseWeight(weight),
|
||||
crossbarUsage(getSpatComputeCrossbarUsage(spatWeightedCompute)),
|
||||
crossbarUsage(getSpatComputeCrossbarUsage(spatCompute)),
|
||||
syntheticId(-1),
|
||||
parents(),
|
||||
children() {}
|
||||
|
||||
TaskDCP(int64_t id, Weight weight, CrossbarUsage crossbarUsage = 0)
|
||||
: onnx_mlir::LabeledListNode<TaskDCP>(),
|
||||
spatWeightedCompute(),
|
||||
spatCompute(),
|
||||
aest(0),
|
||||
alst(0),
|
||||
scheduledCpu(),
|
||||
@@ -90,14 +90,14 @@ public:
|
||||
void setAlst(Time value) { alst = value; }
|
||||
bool hasDescendant(TaskDCP* child);
|
||||
int64_t Id() const {
|
||||
if (spatWeightedCompute)
|
||||
return reinterpret_cast<int64_t>(spatWeightedCompute.getAsOpaquePointer());
|
||||
if (spatCompute)
|
||||
return reinterpret_cast<int64_t>(spatCompute.getAsOpaquePointer());
|
||||
return syntheticId;
|
||||
}
|
||||
|
||||
bool isCriticalPath() const { return alst == aest; }
|
||||
bool isScheduled() const { return scheduledCpu.has_value(); }
|
||||
onnx_mlir::spatial::SpatWeightedCompute getSpatWeightedCompute() const { return spatWeightedCompute; }
|
||||
onnx_mlir::spatial::SpatCompute getSpatCompute() const { return spatCompute; }
|
||||
|
||||
void setFlag(long long val) { flag = val; }
|
||||
long long getFlag() const { return flag; }
|
||||
|
||||
@@ -92,18 +92,18 @@ inline T subtractOrZero(T lhs, T rhs) {
|
||||
|
||||
inline Time slackOrZero(Time earliestStart, Time latestStart) { return subtractOrZero(latestStart, earliestStart); }
|
||||
|
||||
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
||||
inline Weight getSpatComputeWeight(onnx_mlir::spatial::SpatCompute spatCompute) {
|
||||
constexpr Weight kOperationWeight = 100;
|
||||
Weight numOperations = 0;
|
||||
for (auto& block : spatWeightedCompute.getBody())
|
||||
for (auto& block : spatCompute.getBody())
|
||||
for ([[maybe_unused]] auto& op : block)
|
||||
numOperations = checkedAdd(numOperations, static_cast<Weight>(1));
|
||||
return checkedMultiply(numOperations, kOperationWeight);
|
||||
}
|
||||
|
||||
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatWeightedCompute spatWeightedCompute) {
|
||||
inline CrossbarUsage getSpatComputeCrossbarUsage(onnx_mlir::spatial::SpatCompute spatCompute) {
|
||||
CrossbarUsage crossbarUsage = 0;
|
||||
for (auto& region : spatWeightedCompute.getBody())
|
||||
for (auto& region : spatCompute.getBody())
|
||||
for (auto& inst : region)
|
||||
if (llvm::isa<onnx_mlir::spatial::SpatWeightedVMMOp>(inst))
|
||||
crossbarUsage = checkedAdd(crossbarUsage, static_cast<CrossbarUsage>(1));
|
||||
|
||||
@@ -24,30 +24,29 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
using SpatWeightedCompute = spatial::SpatWeightedCompute;
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
|
||||
struct ComputeValueResults {
|
||||
// Value yielded by the yieldOp
|
||||
Value innerValue;
|
||||
SmallVector<Value> innerValues;
|
||||
|
||||
Value get(size_t resultIndex) const {
|
||||
assert(resultIndex < innerValues.size() && "compute result index out of range");
|
||||
return innerValues[resultIndex];
|
||||
}
|
||||
};
|
||||
|
||||
class LazyInsertComputeResult {
|
||||
using InsertPoint = mlir::IRRewriter::InsertPoint;
|
||||
ComputeValueResults computeResults;
|
||||
Value channelValue;
|
||||
bool onlyChannel;
|
||||
std::function<void(InsertPoint insertPoint)> channelSendInserter;
|
||||
InsertPoint sendInsertPoint;
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter;
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter;
|
||||
|
||||
public:
|
||||
LazyInsertComputeResult(ComputeValueResults computeValueResults,
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>()> channelNewInserter,
|
||||
std::function<std::pair<Value, std::function<void(InsertPoint)>>(size_t)> channelNewInserter,
|
||||
bool isOnlyChannel)
|
||||
: computeResults(computeValueResults),
|
||||
onlyChannel(isOnlyChannel),
|
||||
channelSendInserter(nullptr),
|
||||
sendInsertPoint({}),
|
||||
channelNewInserter(channelNewInserter) {}
|
||||
|
||||
struct ChannelOrLocalOp {
|
||||
@@ -57,12 +56,12 @@ public:
|
||||
|
||||
bool onlyChanneled() const { return onlyChannel; }
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatWeightedCompute currentCompute) {
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(SpatCompute currentCompute, size_t resultIndex) {
|
||||
Value innerValue = computeResults.get(resultIndex);
|
||||
|
||||
auto [newChannelValue, senderInserter] = channelNewInserter();
|
||||
channelValue = newChannelValue;
|
||||
channelSendInserter = senderInserter;
|
||||
auto* block = computeResults.innerValue.getParentBlock();
|
||||
auto [channelValue, channelSendInserter] = channelNewInserter(resultIndex);
|
||||
InsertPoint sendInsertPoint;
|
||||
auto* block = innerValue.getParentBlock();
|
||||
if (!block->empty() && isa<spatial::SpatYieldOp>(block->back()))
|
||||
sendInsertPoint = InsertPoint(block, --block->end());
|
||||
else
|
||||
@@ -70,28 +69,30 @@ public:
|
||||
if (currentCompute) {
|
||||
for (auto& block : currentCompute.getBody())
|
||||
if (&block == sendInsertPoint.getBlock())
|
||||
return {computeResults.innerValue, false};
|
||||
return {innerValue, false};
|
||||
}
|
||||
channelSendInserter(sendInsertPoint);
|
||||
return {channelValue, true};
|
||||
}
|
||||
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender() { return getAsChannelValueAndInsertSender({}); }
|
||||
ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex) {
|
||||
return getAsChannelValueAndInsertSender({}, resultIndex);
|
||||
}
|
||||
};
|
||||
|
||||
struct MergeComputeNodesPass : PassWrapper<MergeComputeNodesPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
private:
|
||||
DenseMap<SpatWeightedCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
DenseMap<SpatWeightedCompute, SpatWeightedCompute> oldToNewComputeMap;
|
||||
DenseMap<int64_t, SpatWeightedCompute> cpuToNewComputeMap;
|
||||
DenseMap<SpatCompute, LazyInsertComputeResult> newComputeNodeResults;
|
||||
DenseMap<SpatCompute, SpatCompute> oldToNewComputeMap;
|
||||
DenseMap<int64_t, SpatCompute> cpuToNewComputeMap;
|
||||
|
||||
public:
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeComputeNodesPass)
|
||||
|
||||
StringRef getArgument() const override { return "pim-merge-compute-nodes-pass"; }
|
||||
StringRef getDescription() const override {
|
||||
return "Merge Spatial-Weighted-Compute-Nodes in order to reduce the total "
|
||||
return "Merge Spatial-Compute-Nodes in order to reduce the total "
|
||||
"execution time";
|
||||
}
|
||||
|
||||
@@ -105,22 +106,22 @@ public:
|
||||
for (auto currentComputeNode : analysisResult.dominanceOrderCompute) {
|
||||
size_t cpu = analysisResult.computeToCpuMap.at(currentComputeNode);
|
||||
if (!cpuToNewComputeMap.contains(cpu)) {
|
||||
ValueTypeRange<ResultRange> newWeightedComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||
auto [newWeightedCompute, computeValueResult] = createNewComputeNode(
|
||||
currentComputeNode, newWeightedComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cpuToNewComputeMap[cpu] = newWeightedCompute;
|
||||
ValueTypeRange<ResultRange> newComputeType = cpuToLastComputeMap.at(cpu).getResultTypes();
|
||||
auto [newCompute, computeValueResult] = createNewComputeNode(
|
||||
currentComputeNode, newComputeType, lastComputeOfCpu.contains(currentComputeNode));
|
||||
cpuToNewComputeMap[cpu] = newCompute;
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
}
|
||||
else {
|
||||
auto [newWeightedCompute, computeValueResult] = mergeIntoComputeNode(
|
||||
auto [newCompute, computeValueResult] = mergeIntoComputeNode(
|
||||
cpuToNewComputeMap[cpu], currentComputeNode, lastComputeOfCpu.contains(currentComputeNode));
|
||||
newComputeNodeResults.insert(
|
||||
std::make_pair(currentComputeNode,
|
||||
createLazyComputeResult(
|
||||
newWeightedCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
newCompute, computeValueResult, lastComputeOfCpu.contains(currentComputeNode))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,8 +135,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::pair<SpatWeightedCompute, ComputeValueResults> createNewComputeNode(
|
||||
SpatWeightedCompute oldWeightedCompute, ValueTypeRange<ResultRange> newWeightedComputeType, bool lastCompute) {
|
||||
std::pair<SpatCompute, ComputeValueResults> createNewComputeNode(
|
||||
SpatCompute oldCompute, ValueTypeRange<ResultRange> newComputeType, bool lastCompute) {
|
||||
func::FuncOp func = getOperation();
|
||||
auto loc = func.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
@@ -148,50 +149,53 @@ private:
|
||||
llvm::SmallVector<Type> newBBOperandType;
|
||||
llvm::SmallVector<Location> newBBLocations;
|
||||
|
||||
for (auto arg : oldWeightedCompute.getWeights())
|
||||
for (auto arg : oldCompute.getWeights())
|
||||
newComputeOperand.push_back(arg);
|
||||
|
||||
for (auto arg : oldWeightedCompute.getInputs())
|
||||
if (!llvm::isa_and_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
for (auto arg : oldCompute.getInputs())
|
||||
if (!llvm::isa_and_present<SpatCompute>(arg.getDefiningOp())) {
|
||||
newComputeOperand.push_back(arg);
|
||||
newBBOperandType.push_back(arg.getType());
|
||||
newBBLocations.push_back(loc);
|
||||
}
|
||||
|
||||
auto newWeightedCompute = SpatWeightedCompute::create(rewriter, loc, newWeightedComputeType, newComputeOperand);
|
||||
auto newCompute = SpatCompute::create(rewriter, loc, newComputeType, newComputeOperand);
|
||||
|
||||
rewriter.createBlock(
|
||||
&newWeightedCompute.getBody(), newWeightedCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
newWeightedCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) oldWeightedCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||
&newCompute.getBody(), newCompute.getBody().end(), newBBOperandType, newBBLocations);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{(int) oldCompute.getWeights().size(), (int) newBBOperandType.size()});
|
||||
|
||||
auto& newBB = newWeightedCompute.getBody().front();
|
||||
auto& oldBB = oldWeightedCompute.getBody().front();
|
||||
auto& newBB = newCompute.getBody().front();
|
||||
auto& oldBB = oldCompute.getBody().front();
|
||||
rewriter.setInsertionPointToEnd(&newBB);
|
||||
|
||||
int indexNew = 0;
|
||||
size_t indexOld = oldWeightedCompute.getWeights().size();
|
||||
size_t indexOldStart = oldWeightedCompute.getWeights().size();
|
||||
for (; indexOld < oldWeightedCompute.getNumOperands(); ++indexOld) {
|
||||
if (!llvm::isa_and_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp())) {
|
||||
size_t indexOld = oldCompute.getWeights().size();
|
||||
size_t indexOldStart = oldCompute.getWeights().size();
|
||||
for (; indexOld < oldCompute.getNumOperands(); ++indexOld) {
|
||||
if (!llvm::isa_and_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp())) {
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), newBB.getArgument(indexNew++));
|
||||
}
|
||||
else {
|
||||
auto argWeightCompute =
|
||||
llvm::dyn_cast_if_present<SpatWeightedCompute>(oldWeightedCompute.getOperand(indexOld).getDefiningOp());
|
||||
llvm::dyn_cast_if_present<SpatCompute>(oldCompute.getOperand(indexOld).getDefiningOp());
|
||||
auto argResultIndex = cast<OpResult>(oldCompute.getOperand(indexOld)).getResultNumber();
|
||||
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender();
|
||||
auto [channelVal, isChannel] = lazyArgWeight.getAsChannelValueAndInsertSender(argResultIndex);
|
||||
assert(isChannel == true);
|
||||
spatial::SpatChannelReceiveOp receiveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelVal);
|
||||
spatial::SpatChannelReceiveOp receiveOp = spatial::SpatChannelReceiveOp::create(
|
||||
rewriter, loc, oldCompute.getOperand(indexOld).getType(), channelVal);
|
||||
mapper.map(oldBB.getArgument(indexOld - indexOldStart), receiveOp);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& op : oldWeightedCompute.getOps()) {
|
||||
for (auto& op : oldCompute.getOps()) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
|
||||
computeValueResults.innerValues.reserve(yield.getNumOperands());
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
|
||||
if (lastCompute)
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
@@ -199,16 +203,18 @@ private:
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
|
||||
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
|
||||
if (isa<func::ReturnOp>(use.getOwner()))
|
||||
use.assign(newWeightedCompute.getResult(0));
|
||||
for (auto& use : llvm::make_early_inc_range(oldCompute->getUses()))
|
||||
if (isa<func::ReturnOp>(use.getOwner())) {
|
||||
auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
|
||||
use.assign(newCompute.getResult(resultIndex));
|
||||
}
|
||||
|
||||
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
|
||||
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};
|
||||
oldToNewComputeMap.insert({oldCompute, newCompute});
|
||||
return {cast<SpatCompute>(newCompute), computeValueResults};
|
||||
}
|
||||
|
||||
std::pair<SpatWeightedCompute, ComputeValueResults>
|
||||
mergeIntoComputeNode(SpatWeightedCompute toCompute, SpatWeightedCompute fromCompute, bool lastCompute) {
|
||||
std::pair<SpatCompute, ComputeValueResults>
|
||||
mergeIntoComputeNode(SpatCompute toCompute, SpatCompute fromCompute, bool lastCompute) {
|
||||
func::FuncOp func = getOperation();
|
||||
auto loc = func.getLoc();
|
||||
IRRewriter rewriter(&getContext());
|
||||
@@ -239,14 +245,15 @@ private:
|
||||
// Insert receiveOp
|
||||
rewriter.setInsertionPointToEnd(&toBB);
|
||||
for (auto [bbIndex, arg] : llvm::enumerate(fromCompute.getInputs())) {
|
||||
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatWeightedCompute>(arg.getDefiningOp())) {
|
||||
if (auto argWeightCompute = llvm::dyn_cast_if_present<SpatCompute>(arg.getDefiningOp())) {
|
||||
LazyInsertComputeResult& lazyArgWeight = newComputeNodeResults.at(argWeightCompute);
|
||||
auto argResultIndex = cast<OpResult>(arg).getResultNumber();
|
||||
|
||||
LazyInsertComputeResult::ChannelOrLocalOp channelOrLocal =
|
||||
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute);
|
||||
lazyArgWeight.getAsChannelValueAndInsertSender(toCompute, argResultIndex);
|
||||
if (channelOrLocal.isChannel) {
|
||||
spatial::SpatChannelReceiveOp receiveOp =
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, argWeightCompute.getType(0), channelOrLocal.data);
|
||||
spatial::SpatChannelReceiveOp::create(rewriter, loc, arg.getType(), channelOrLocal.data);
|
||||
mapper.map(fromBB.getArgument(bbIndex), receiveOp.getResult());
|
||||
}
|
||||
else {
|
||||
@@ -286,7 +293,9 @@ private:
|
||||
};
|
||||
for (auto& op : fromCompute.getOps()) {
|
||||
if (auto yield = dyn_cast<spatial::SpatYieldOp>(&op)) {
|
||||
computeValueResults.innerValue = mapper.lookup(yield.getOperand(0));
|
||||
computeValueResults.innerValues.reserve(yield.getNumOperands());
|
||||
for (Value yieldOperand : yield.getOperands())
|
||||
computeValueResults.innerValues.push_back(mapper.lookup(yieldOperand));
|
||||
if (lastCompute)
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
@@ -299,33 +308,36 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
for (auto users : fromCompute->getUsers())
|
||||
if (auto funcRet = dyn_cast<func::ReturnOp>(users))
|
||||
funcRet.setOperand(0, toCompute.getResult(0));
|
||||
for (auto& use : llvm::make_early_inc_range(fromCompute->getUses()))
|
||||
if (isa<func::ReturnOp>(use.getOwner())) {
|
||||
auto resultIndex = cast<OpResult>(use.get()).getResultNumber();
|
||||
use.assign(toCompute.getResult(resultIndex));
|
||||
}
|
||||
|
||||
oldToNewComputeMap.insert({fromCompute, toCompute});
|
||||
return {cast<SpatWeightedCompute>(toCompute), computeValueResults};
|
||||
return {cast<SpatCompute>(toCompute), computeValueResults};
|
||||
}
|
||||
|
||||
LazyInsertComputeResult createLazyComputeResult(SpatWeightedCompute weightedCompute,
|
||||
LazyInsertComputeResult createLazyComputeResult(SpatCompute compute,
|
||||
ComputeValueResults computeValueResults,
|
||||
bool lastCompute) {
|
||||
func::FuncOp funcOp = cast<func::FuncOp>(weightedCompute->getParentOp());
|
||||
func::FuncOp funcOp = cast<func::FuncOp>(compute->getParentOp());
|
||||
auto* context = &getContext();
|
||||
auto loc = funcOp.getLoc();
|
||||
IRRewriter rewriter(context);
|
||||
|
||||
rewriter.setInsertionPointToStart(&funcOp.front());
|
||||
auto savedChannelInsertPoint = rewriter.saveInsertionPoint();
|
||||
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults]() {
|
||||
auto insertNew = [savedChannelInsertPoint, context, loc, computeValueResults](size_t resultIndex) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(savedChannelInsertPoint);
|
||||
auto channelOp = spatial::SpatChannelNewOp::create(rewriter, loc, spatial::SpatChannelType::get(context));
|
||||
auto channelVal = channelOp.getResult();
|
||||
auto insertVal = [&context, loc, computeValueResults, channelVal](mlir::IRRewriter::InsertPoint sendInsertPoint) {
|
||||
auto insertVal =
|
||||
[&context, loc, computeValueResults, channelVal, resultIndex](mlir::IRRewriter::InsertPoint sendInsertPoint) {
|
||||
IRRewriter rewriter(context);
|
||||
rewriter.restoreInsertionPoint(sendInsertPoint);
|
||||
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.innerValue);
|
||||
auto spatSend = spatial::SpatChannelSendOp::create(rewriter, loc, channelVal, computeValueResults.get(resultIndex));
|
||||
return spatSend;
|
||||
};
|
||||
std::pair<Value, std::function<void(mlir::IRRewriter::InsertPoint)>> ret {channelVal, insertVal};
|
||||
|
||||
Reference in New Issue
Block a user