Add register reuse + peft scheduler cost model + Useless merger

This commit is contained in:
ilgeco
2026-06-18 10:56:57 +02:00
parent 852bef7605
commit e083c27d80
13 changed files with 350 additions and 20 deletions
+14
View File
@@ -17,6 +17,20 @@ std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRe
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return {};
std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir);
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out | std::ios::app);
}
std::fstream openAppendedReportFile(const std::string& name) {
return openAppendedReportFileWithExtension(name, "txt");
}
std::string formatReportMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
+2
View File
@@ -12,6 +12,8 @@ namespace onnx_mlir {
std::fstream openReportFile(const std::string& name);
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
std::fstream openAppendedReportFile(const std::string& name);
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension);
std::string formatReportMemory(uint64_t bytes);
struct ReportField {
+26 -2
View File
@@ -588,13 +588,37 @@ void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instructio
++emittedInstructionCount;
if (coreJsonStream)
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
updateScalarRegisterCache(instruction);
}
void PimCodeGen::updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const {
switch (instruction.opcode) {
case pim_binary::Opcode::sldi:
scalarRegisterValues[instruction.rd] = instruction.r2OrImm;
break;
case pim_binary::Opcode::sld:
case pim_binary::Opcode::sadd:
case pim_binary::Opcode::ssub:
case pim_binary::Opcode::smul:
case pim_binary::Opcode::saddi:
case pim_binary::Opcode::smuli:
scalarRegisterValues[instruction.rd].reset();
break;
default:
break;
}
}
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
auto registerIndex = pim::checkedU8OrCrash(registerNumber, "register number");
auto immediateValue = pim::checkedI32OrCrash(immediate, "register immediate");
if (scalarRegisterValues[registerIndex] == immediateValue)
return;
pim_binary::InstructionRecord instruction;
instruction.opcode = pim_binary::Opcode::sldi;
instruction.rd = static_cast<uint8_t>(registerNumber);
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
instruction.rd = registerIndex;
instruction.r2OrImm = immediateValue;
emitInstruction(instruction);
}
+3
View File
@@ -9,6 +9,7 @@
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <array>
#include <fstream>
#include <limits>
#include <optional>
@@ -170,6 +171,7 @@ class PimCodeGen {
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0;
mutable std::array<std::optional<int32_t>, 256> scalarRegisterValues = {};
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge, batchLane);
@@ -177,6 +179,7 @@ class PimCodeGen {
size_t remapCoreId(size_t coreId) const;
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
void updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const;
+42
View File
@@ -32,6 +32,31 @@ llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
llvm::cl::init(PimMemoryReportNone),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimConvLoweringType> pimConvLowering(
"pim-conv-lowering",
llvm::cl::desc("Convolution lowering strategy for PIM"),
llvm::cl::values(clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically")),
llvm::cl::values(clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering")),
llvm::cl::values(clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering")),
llvm::cl::values(
clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM")),
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPatch,
"streamed-patch",
"Use streamed/chunked im2col rows without multi-position packing")),
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPacked,
"streamed-packed",
"Use streamed/chunked im2col rows with packed multi-position GEMM")),
llvm::cl::values(clEnumValN(PimConvLoweringOutputChannelTiled,
"output-channel-tiled",
"Force Conv lowering that relies on Gemm output-channel tiling")),
llvm::cl::values(
clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling")),
llvm::cl::values(clEnumValN(PimConvLoweringTiled2D,
"tiled-2d",
"Force Conv lowering that relies on Gemm 2D K/C tiling")),
llvm::cl::init(PimConvLoweringAuto),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
@@ -49,6 +74,23 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<uint64_t> pimConvIm2colMaxElements(
"pim-conv-im2col-max-elements",
llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"),
llvm::cl::init(1ull << 20),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<uint64_t> pimConvStreamChunkPositions(
"pim-conv-stream-chunk-positions",
llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"),
llvm::cl::init(1024),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimReportConvLowering("pim-report-conv-lowering",
llvm::cl::desc("Emit a bounded Conv lowering report"),
llvm::cl::init(true),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
llvm::cl::init(false),
+16
View File
@@ -30,19 +30,35 @@ typedef enum {
PimMemoryReportFull = 2,
} PimMemoryReportLevel;
typedef enum {
PimConvLoweringAuto = 0,
PimConvLoweringLegacy = 1,
PimConvLoweringDepthwise = 2,
PimConvLoweringPackedIm2Col = 3,
PimConvLoweringStreamedPatch = 4,
PimConvLoweringStreamedPacked = 5,
PimConvLoweringOutputChannelTiled = 6,
PimConvLoweringInputKTiled = 7,
PimConvLoweringTiled2D = 8,
} PimConvLoweringType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
extern llvm::cl::opt<PimConvLoweringType> pimConvLowering;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<bool> pimReportConvLowering;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<uint64_t> pimConvIm2colMaxElements;
extern llvm::cl::opt<uint64_t> pimConvStreamChunkPositions;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
@@ -19,9 +19,11 @@ using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
if (!rankedType)
return false;
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
if (!sourceType)
return false;
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
return false;
return true;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
return false;
}
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
continue;
}
if (isWeightLikeComputeOperand(operand)) {
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
@@ -26,6 +26,82 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
});
}
static bool isMaterializableExternalTensorOp(Operation* op) {
return isa<spatial::SpatChannelReceiveOp,
spatial::SpatExtractRowsOp,
tensor::ExtractSliceOp,
tensor::ExpandShapeOp,
tensor::CollapseShapeOp>(op);
}
//TODO REMOVE THIS UGLY FIX
//TODO: Remove this helper once compute_batch external tensor captures are
// fixed at the producer side.
//
// This function is a temporary SpatialToPim repair path. It clones selected
// external tensor producers, such as channel_receive and tensor view/slice ops,
// into the new pim.core_batch body when the old spat.compute_batch body refers
// to tensor values defined outside the batch.
//
// The real invariant should be stronger:
//
// A spat.compute_batch body must not capture external tensor values.
// Every tensor used inside the body must be either:
// - a compute_batch block argument,
// - defined inside the compute_batch body,
// - or a legal constant-like value.
//
// If this invariant is violated, the responsible producer, most likely merge
// schedule materialization, should emit verifier-clean Spatial IR instead of
// relying on SpatialToPim to clone external producer chains later.
//
// After that producer-side fix:
// 1. remove isMaterializableExternalTensorOp,
// 2. remove materializeExternalTensorValue,
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
// tensor operand,
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
// before SpatialToPim.
//
// Be careful not to replace every external tensor capture with a normal
// compute_batch input blindly: host-backed tensors and explicit inter-core
// communication have different semantics. In particular, channel_receive-like
// values should be materialized through the communication model, not silently
// treated as host inputs.
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
Location loc,
Block& oldBlock,
Value value,
IRMapping& mapper) {
if (mapper.contains(value))
return mapper.lookup(value);
if (!isa<TensorType>(value.getType()))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
return failure();
if (definingOp->getBlock() == &oldBlock)
return failure();
if (!isMaterializableExternalTensorOp(definingOp))
return failure();
for (Value operand : definingOp->getOperands()) {
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
if (succeeded(materializedOperand))
mapper.map(operand, *materializedOperand);
}
Operation* cloned = rewriter.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
@@ -264,9 +340,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
return computeBatchOp.emitOpError(
"expected external tensor communication to be materialized in Spatial before batch lowering");
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
continue;
InFlightDiagnostic diagnostic =
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
if (definingOp)
diagnostic << " from external producer '" << definingOp->getName() << "'";
return diagnostic;
}
Operation* cloned = rewriter.clone(op, mapper);
@@ -2018,6 +2018,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
struct PendingProjectedTransferDescriptor {
ProjectedBatchInputKey inputKey;
Operation* extractOp = nullptr;
RankedTensorType sourceType;
RankedTensorType fragmentType;
SmallVector<int64_t, 4> fragmentShape;
SmallVector<SmallVector<SmallVector<int64_t, 4>, 16>, 8> fragmentOffsetsByLane;
@@ -2029,6 +2030,20 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) {
if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType)
return false;
if (descriptor.fragmentOffsetsByLane.size() != 1)
return false;
ArrayRef<SmallVector<int64_t, 4>> fragments = descriptor.fragmentOffsetsByLane.front();
if (fragments.size() != 1)
return false;
return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; });
};
const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor,
unsigned targetLane,
const AffineProjectedInputSliceMatch& match,
@@ -2117,6 +2132,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
if (descriptor.fragmentOffsetsByLane.empty()) {
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
descriptor.extractOp = match->extract.getOperation();
descriptor.sourceType = match->sourceType;
descriptor.fragmentType = match->fragmentType;
descriptor.fragmentShape = match->fragmentShape;
descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1);
@@ -2132,7 +2148,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation()
|| descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape
|| descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType
|| descriptor.fragmentShape != match->fragmentShape
|| descriptor.loopLowerBounds.size() != match->loops.size()) {
descriptor.invalid = true;
continue;
@@ -2175,6 +2192,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
continue;
if (pendingDescriptor.fragmentOffsetsByLane.empty())
continue;
if (isIdentityProjectedTransfer(pendingDescriptor))
continue;
MaterializedClass& targetClass = state.classes[targetClassId];
ProjectedTransferDescriptor descriptor;
@@ -2755,8 +2774,14 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
if (*descriptor) {
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) {
if (!fanoutPlan.ordinaryMessages)
fanoutPlan.ordinaryMessages = MessageVector {};
fanoutPlan.ordinaryMessages->append(
receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds);
fanoutPlan.receivePlans.push_back(std::move(receivePlan));
continue;
}
receivePlan.receiveType = projectedDescriptor.payloadType;
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
@@ -42,7 +42,6 @@ namespace {
using namespace onnx_mlir::compact_asm;
using SpatCompute = spatial::SpatCompute;
using SpatComputeBatch = spatial::SpatComputeBatch;
using spatial::getProducerValueRef;
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
@@ -187,13 +186,23 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds;
};
//TODO Used for report refactor
struct CollectorConcatRow {
uint64_t computeId = 0;
int32_t coreId = -1;
uint64_t operandCount = 0;
};
uint64_t totalComputeOps = 0;
uint64_t totalLogicalComputes = 0;
uint64_t totalBatchComputeOps = 0;
uint64_t totalInstructionCount = 0;
uint64_t totalCrossbarCount = 0;
uint64_t nextBatchId = 0;
//TODO Used for report refactor
std::vector<ReportRow> collectedData;
//TODO Used for report refactor
std::vector<CollectorConcatRow> collectorConcatRows;
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
@@ -206,7 +215,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
SmallVector<int32_t> coreIds;
if (auto coreId = getComputeCoreId(spatCompute))
coreIds.push_back(*coreId);
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t computeId = totalComputeOps++;
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
uint64_t maxConcatOperands = 0;
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
});
//TODO 128 is a magic number
if (maxConcatOperands >= 128 && !coreIds.empty())
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
totalLogicalComputes += 1;
totalInstructionCount += numInst;
totalCrossbarCount += perInstanceCrossbarCount;
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
};
printReportTotalsBlock(os, totalFields);
if (!collectedData.empty())
if (!collectedData.empty() || !collectorConcatRows.empty())
os << "\n";
if (!collectorConcatRows.empty()) {
os << "Collector concat materialization:\n";
for (const CollectorConcatRow& row : collectorConcatRows)
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
os << "\n";
}
sortReportEntriesByFirstCore(collectedData);
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
llvm_unreachable("unknown merge scheduler kind");
}
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
void verifySchedule(const ComputeGraph& graph,
const MergeScheduleResult& result,
unsigned long crossbarCapacity,
size_t processorCount) {
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
if (sourceCpu != targetCpu)
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
earliestTargetStart = addOrMax(
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
if (targetStart < earliestTargetStart) {
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
graph.nodes[edge.source].originalOrder,
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
static_cast<unsigned long>(crossbarCountInCore.getValue()),
entryOp->getContext()});
}
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
verifySchedule(graph,
schedule,
static_cast<unsigned long>(crossbarCountInCore.getValue()),
options.processorCount);
return schedule;
}
@@ -4,6 +4,7 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <cmath>
#include <limits>
#include <queue>
#include <vector>
@@ -21,6 +22,63 @@ struct ScheduledTask {
Time endTime = 0;
};
struct MeshModel {
size_t rows = 1;
size_t cols = 1;
long double averageDistance = 0.0L;
static MeshModel infer(size_t processorCount) {
MeshModel model;
if (processorCount == 0)
return model;
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
if (model.rows == 0)
model.rows = 1;
while (model.rows > 1 && processorCount % model.rows != 0)
--model.rows;
model.cols = (processorCount + model.rows - 1) / model.rows;
auto averageAxisDistance = [](size_t size) -> long double {
if (size <= 1)
return 0.0L;
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
};
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
return model;
}
std::pair<size_t, size_t> getCoord(size_t processor) const {
return {processor / cols, processor % cols};
}
size_t getDistance(size_t lhs, size_t rhs) const {
auto [lhsRow, lhsCol] = getCoord(lhs);
auto [rhsRow, rhsCol] = getCoord(rhs);
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
return rowDistance + colDistance;
}
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
if (sourceProcessor == targetProcessor || transferCost == 0)
return 0;
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
scale = std::max(0.25L, scale);
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
}
size_t getCenterDistance(size_t processor) const {
auto [row, col] = getCoord(processor);
size_t centerRow = rows / 2;
size_t centerCol = cols / 2;
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
return rowDistance + colDistance;
}
};
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
std::queue<size_t> readySinks;
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
} // namespace
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
return MeshModel::infer(processorCount).scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
}
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
const size_t nodeCount = graph.nodes.size();
const size_t processorCount = options.processorCount;
if (processorCount == 0)
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
MeshModel mesh = MeshModel::infer(processorCount);
verifyOctTableSize(nodeCount, processorCount);
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
// MOCK: Replace this with your actual heterogeneous cost lookup.
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
std::vector<Time> oct(nodeCount * processorCount, 0);
std::vector<Time> minOctPlusComp(nodeCount, 0);
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time bestEft = 0;
Time bestOeft = std::numeric_limits<Time>::max();
unsigned int bestOverlapCount = 0;
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
bool crossbarRejected = false;
for (size_t processor = 0; processor < processorCount; ++processor) {
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time dataReady = 0;
for (const auto& [pred, comm] : graph.predecessors[task]) {
const ScheduledTask& predSchedule = schedules[pred];
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
}
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
Time eft = addOrMax(est, computeCost);
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
size_t centerDistance = mesh.getCenterDistance(processor);
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
else if (oeft == bestOeft && eft == bestEft && est == bestEst
&& centerDistance < bestCenterDistance) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
else if (oeft == bestOeft && eft == bestEft && est == bestEst
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
bestProcessor = processor;
bestEst = est;
bestEft = eft;
bestOeft = oeft;
bestOverlapCount = overlapCount;
bestCenterDistance = centerDistance;
}
}
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
mlir::MLIRContext* context = nullptr;
};
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
} // namespace spatial