This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -8,7 +11,7 @@ namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
namespace {
|
||||
|
||||
std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
|
||||
std::optional<BlockArgument> getBlockArgument(Region& body, unsigned argIdx) {
|
||||
if (body.empty())
|
||||
return std::nullopt;
|
||||
|
||||
@@ -18,7 +21,7 @@ std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx)
|
||||
return block.getArgument(argIdx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
||||
std::optional<BlockArgument> insertBlockArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
||||
if (body.empty())
|
||||
return std::nullopt;
|
||||
return body.insertArgument(argIdx, type, loc);
|
||||
@@ -34,21 +37,27 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Returns the body-region block argument at position `idx`, which this op
/// uses for its `idx`-th weight. The lookup is delegated to a file-local
/// helper that yields std::nullopt when the region has no blocks.
///
/// Two renderings of this accessor are shown here; they differ only in the
/// helper name (getBatchBodyArgument vs. getBlockArgument).
// NOTE(review): presumably the first is the pre-rename side of the diff and
// the single-line form is the one that remains — confirm against the commit.
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), idx);
|
||||
}
|
||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBlockArgument(getBody(), idx); }
|
||||
|
||||
/// Returns the body-region block argument for the `idx`-th input. The index
/// is offset by the current number of weights, i.e. input arguments are
/// looked up after all weight arguments in the block's argument list.
// NOTE(review): the two return lines are the two sides of a helper rename
// (getBatchBodyArgument -> getBlockArgument, presumably); only one of them
// exists in the actual source file.
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
||||
return getBlockArgument(getBody(), getWeights().size() + idx);
|
||||
}
|
||||
|
||||
/// Adds `weight` as a weight operand at operand position `idx` and creates the
/// matching block argument (of `weight`'s type, at `loc`) at the same index in
/// the body region.
///
/// If `weight` is already one of the op's weights, nothing is inserted and the
/// existing operand is returned together with its block argument; `idx` is
/// ignored in that case.
///
/// Returns std::nullopt when the block argument could not be inserted;
/// otherwise the operand now at `idx` paired with the new block argument.
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
// Deduplicate: reuse the existing operand/argument pair when the value is
// already registered as a weight.
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||
// NOTE(review): unconditional write of a placeholder-looking message to
// llvm::dbgs(); it is not wrapped in LLVM_DEBUG. Looks like leftover
// debugging output — confirm whether it should be removed or guarded.
llvm::dbgs() << "Disse netanyao\n";
|
||||
auto index = std::distance(getWeights().begin(), existing);
|
||||
return {
|
||||
// NOTE(review): the optional returned by getWeightArgument is dereferenced
// without a check; verify the body region is always populated on this path.
{*existing, *getWeightArgument(index)}
|
||||
};
|
||||
}
|
||||
|
||||
// Counts are captured before the operand list is modified so that the new
// segment sizes can be derived from the old ones.
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
// One more weight, same number of inputs.
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
// The next two lines are the two sides of the helper rename shown in this
// diff; only one `blockArg` declaration exists in the real source.
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
|
||||
auto blockArg = insertBlockArgument(getBody(), idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
@@ -60,7 +69,7 @@ std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigne
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||
auto blockArg = insertBlockArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
@@ -100,28 +109,36 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
/// Returns the first block argument (index 0) of the body region, which this
/// op treats as the lane argument; the helper yields std::nullopt when the
/// region has no blocks.
// NOTE(review): both lines define the same member and differ only in the
// helper name — they are the two sides of the rename shown in this diff.
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
|
||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBlockArgument(getBody(), 0); }
|
||||
|
||||
/// Returns the body-region block argument for the `idx`-th weight. The index
/// is shifted by one because block argument 0 is the lane argument (see
/// getLaneArgument in this file).
// NOTE(review): the two return lines are the before/after sides of the helper
// rename in this diff; only one is present in the real source.
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + idx);
|
||||
return getBlockArgument(getBody(), 1 + idx);
|
||||
}
|
||||
|
||||
/// Returns the body-region block argument for the `idx`-th input. Inputs are
/// looked up after the lane argument (1) and all weight arguments
/// (getWeights().size()).
// NOTE(review): the two return lines are the before/after sides of the helper
// rename in this diff; only one is present in the real source.
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
|
||||
return getBlockArgument(getBody(), 1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
/// Returns the body-region block argument for the `idx`-th output. Outputs are
/// looked up after the lane argument (1), all weight arguments and all input
/// arguments.
// NOTE(review): the two return lines are the before/after sides of the helper
// rename in this diff; only one is present in the real source.
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||
return getBlockArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
/// Adds `weight` as a weight operand at operand position `idx` and creates the
/// matching block argument (of `weight`'s type, at `loc`) at index `1 + idx`
/// in the body region; the extra 1 skips the lane argument at index 0.
///
/// If `weight` is already one of the op's weights, nothing is inserted and the
/// existing operand is returned together with its block argument; `idx` is
/// ignored in that case.
///
/// Returns std::nullopt when the block argument could not be inserted;
/// otherwise the operand now at `idx` paired with the new block argument.
std::optional<std::tuple<Value, BlockArgument>>
|
||||
SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
// Deduplicate: reuse the existing operand/argument pair when the value is
// already registered as a weight.
if (auto existing = llvm::find(getWeights(), weight); existing != getWeights().end()) {
|
||||
auto index = std::distance(getWeights().begin(), existing);
|
||||
// NOTE(review): unconditional write of a placeholder-looking message to
// llvm::dbgs(); it is not wrapped in LLVM_DEBUG. Looks like leftover
// debugging output — confirm whether it should be removed or guarded.
llvm::dbgs() << "Bum bum bum bum\n";
|
||||
return {
|
||||
// NOTE(review): the optional returned by getWeightArgument is dereferenced
// without a check; verify the body region is always populated on this path.
{*existing, *getWeightArgument(index)}
|
||||
};
|
||||
}
|
||||
|
||||
// Counts are captured before the operand list is modified so that the new
// segment sizes can be derived from the old ones.
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
// One more weight, same number of inputs.
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
// The next two lines are the two sides of the helper rename shown in this
// diff; only one `blockArg` declaration exists in the real source.
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||
auto blockArg = insertBlockArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
@@ -133,7 +150,7 @@ std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(un
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||
auto blockArg = insertBlockArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
@@ -177,20 +178,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time bestEst = 0;
|
||||
Time bestEft = 0;
|
||||
Time bestOeft = std::numeric_limits<Time>::max();
|
||||
unsigned int bestOverlapWeight = 0;
|
||||
bool crossbarRejected = false;
|
||||
|
||||
auto crossbarsAreContainedInProcessor = [&processorCrossbars](mlir::Value nodeCrossbar, size_t processor) {
|
||||
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
|
||||
};
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
auto crossbarsAreContained = std::bind(crossbarsAreContainedInProcessor, std::placeholders::_1, processor);
|
||||
if (graph.nodes[task].crossbarUsage.size() != 0
|
||||
&& !llvm::all_of(graph.nodes[task].crossbarUsage,
|
||||
[&processorCrossbars, processor](mlir::Value nodeCrossbar) {
|
||||
return llvm::is_contained(processorCrossbars[processor], nodeCrossbar);
|
||||
})
|
||||
&& !llvm::all_of(graph.nodes[task].crossbarUsage, crossbarsAreContained)
|
||||
&& addOrMax(processorCrossbars[processor].size(), graph.nodes[task].crossbarUsage.size())
|
||||
> options.crossbarCapacity) {
|
||||
|
||||
crossbarRejected = true;
|
||||
continue;
|
||||
}
|
||||
crossbarRejected = true;
|
||||
continue;
|
||||
}
|
||||
unsigned int overlapWeight =
|
||||
llvm::count_if(graph.nodes[task].crossbarUsage, crossbarsAreContained);
|
||||
|
||||
Time dataReady = 0;
|
||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||
@@ -224,12 +230,19 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||
|
||||
// Candidate selection for the current processor: it replaces the best one
// found so far when it has a smaller oeft, or ties on oeft with a smaller eft,
// or ties on both with a smaller est. Two closing forms of the condition are
// shown below because this is a diff rendering: one that additionally breaks
// full ties by lower processor index, and one that stops at the est
// comparison.
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)
|
||||
|| (oeft == bestOeft && eft == bestEft && est == bestEst && processor < bestProcessor)) {
|
||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapWeight = overlapWeight;
|
||||
}
|
||||
// NOTE(review): this branch looks unreachable. Its condition requires
// `oeft == bestOeft && eft == bestEft && est < bestEst`, which is exactly the
// last disjunct of the `if` above, so any candidate satisfying it has already
// been taken there. If the intent is an overlap-weight tie-break, the
// comparison was presumably meant to be `est == bestEst` — confirm the
// intended ordering (and whether a lower or higher overlap weight should win).
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapWeight < bestOverlapWeight) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapWeight = overlapWeight;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user