Add register reuse + peft scheduler cost model + Useless merger
This commit is contained in:
@@ -17,6 +17,20 @@ std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRe
|
|||||||
|
|
||||||
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
|
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
|
||||||
|
|
||||||
|
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
|
||||||
|
std::string outputDir = getOutputDir();
|
||||||
|
if (outputDir.empty())
|
||||||
|
return {};
|
||||||
|
|
||||||
|
std::string reportsDir = outputDir + "/reports";
|
||||||
|
createDirectory(reportsDir);
|
||||||
|
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out | std::ios::app);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::fstream openAppendedReportFile(const std::string& name) {
|
||||||
|
return openAppendedReportFileWithExtension(name, "txt");
|
||||||
|
}
|
||||||
|
|
||||||
std::string formatReportMemory(uint64_t bytes) {
|
std::string formatReportMemory(uint64_t bytes) {
|
||||||
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||||
int i = 0;
|
int i = 0;
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ namespace onnx_mlir {
|
|||||||
|
|
||||||
std::fstream openReportFile(const std::string& name);
|
std::fstream openReportFile(const std::string& name);
|
||||||
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||||
|
std::fstream openAppendedReportFile(const std::string& name);
|
||||||
|
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||||
std::string formatReportMemory(uint64_t bytes);
|
std::string formatReportMemory(uint64_t bytes);
|
||||||
|
|
||||||
struct ReportField {
|
struct ReportField {
|
||||||
|
|||||||
@@ -588,13 +588,37 @@ void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instructio
|
|||||||
++emittedInstructionCount;
|
++emittedInstructionCount;
|
||||||
if (coreJsonStream)
|
if (coreJsonStream)
|
||||||
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
||||||
|
updateScalarRegisterCache(instruction);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Keeps the per-register cache of known immediates in sync with an emitted
/// instruction.
///
/// - `sldi` records its immediate (`r2OrImm`) as the known value of `rd`.
/// - `sld`, `sadd`, `ssub`, `smul`, `saddi` and `smuli` clear the entry for
///   `rd` (presumably because they write a value the cache cannot predict).
/// - Every other opcode leaves the cache untouched.
void PimCodeGen::updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const {
  using Opcode = pim_binary::Opcode;
  const auto opcode = instruction.opcode;

  if (opcode == Opcode::sldi) {
    scalarRegisterValues[instruction.rd] = instruction.r2OrImm;
    return;
  }

  const bool invalidatesDestination = opcode == Opcode::sld || opcode == Opcode::sadd || opcode == Opcode::ssub
                                   || opcode == Opcode::smul || opcode == Opcode::saddi || opcode == Opcode::smuli;
  if (invalidatesDestination)
    scalarRegisterValues[instruction.rd].reset();
}
|
||||||
|
|
||||||
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
||||||
|
auto registerIndex = pim::checkedU8OrCrash(registerNumber, "register number");
|
||||||
|
auto immediateValue = pim::checkedI32OrCrash(immediate, "register immediate");
|
||||||
|
if (scalarRegisterValues[registerIndex] == immediateValue)
|
||||||
|
return;
|
||||||
|
|
||||||
pim_binary::InstructionRecord instruction;
|
pim_binary::InstructionRecord instruction;
|
||||||
instruction.opcode = pim_binary::Opcode::sldi;
|
instruction.opcode = pim_binary::Opcode::sldi;
|
||||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
instruction.rd = registerIndex;
|
||||||
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
|
instruction.r2OrImm = immediateValue;
|
||||||
emitInstruction(instruction);
|
emitInstruction(instruction);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include "llvm/Support/JSON.h"
|
#include "llvm/Support/JSON.h"
|
||||||
#include "llvm/Support/raw_os_ostream.h"
|
#include "llvm/Support/raw_os_ostream.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <optional>
|
#include <optional>
|
||||||
@@ -170,6 +171,7 @@ class PimCodeGen {
|
|||||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||||
std::optional<unsigned> batchLane;
|
std::optional<unsigned> batchLane;
|
||||||
mutable uint32_t emittedInstructionCount = 0;
|
mutable uint32_t emittedInstructionCount = 0;
|
||||||
|
mutable std::array<std::optional<int32_t>, 256> scalarRegisterValues = {};
|
||||||
|
|
||||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||||
return memory.getValueAddress(value, knowledge, batchLane);
|
return memory.getValueAddress(value, knowledge, batchLane);
|
||||||
@@ -177,6 +179,7 @@ class PimCodeGen {
|
|||||||
size_t remapCoreId(size_t coreId) const;
|
size_t remapCoreId(size_t coreId) const;
|
||||||
|
|
||||||
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||||
|
void updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const;
|
||||||
|
|
||||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||||
|
|||||||
@@ -32,6 +32,31 @@ llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
|
|||||||
llvm::cl::init(PimMemoryReportNone),
|
llvm::cl::init(PimMemoryReportNone),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
// `--pim-conv-lowering=<strategy>`: selects the Conv lowering strategy for
// PIM. Defaults to `auto`. All accepted literals are registered through one
// variadic cl::values list, in the same order as before.
llvm::cl::opt<PimConvLoweringType> pimConvLowering(
    "pim-conv-lowering",
    llvm::cl::desc("Convolution lowering strategy for PIM"),
    llvm::cl::values(
        clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically"),
        clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering"),
        clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering"),
        clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM"),
        clEnumValN(PimConvLoweringStreamedPatch,
                   "streamed-patch",
                   "Use streamed/chunked im2col rows without multi-position packing"),
        clEnumValN(PimConvLoweringStreamedPacked,
                   "streamed-packed",
                   "Use streamed/chunked im2col rows with packed multi-position GEMM"),
        clEnumValN(PimConvLoweringOutputChannelTiled,
                   "output-channel-tiled",
                   "Force Conv lowering that relies on Gemm output-channel tiling"),
        clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling"),
        clEnumValN(PimConvLoweringTiled2D, "tiled-2d", "Force Conv lowering that relies on Gemm 2D K/C tiling")),
    llvm::cl::init(PimConvLoweringAuto),
    llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool>
|
llvm::cl::opt<bool>
|
||||||
pimOnlyCodegen("pim-only-codegen",
|
pimOnlyCodegen("pim-only-codegen",
|
||||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||||
@@ -49,6 +74,23 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
|||||||
llvm::cl::init(false),
|
llvm::cl::init(false),
|
||||||
llvm::cl::cat(OnnxMlirOptions));
|
llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
|
// `--pim-conv-im2col-max-elements`: upper bound on the im2col elements that
// are materialized globally for one Conv before streaming/chunking is used.
// Default: 1 << 20 elements.
llvm::cl::opt<uint64_t> pimConvIm2colMaxElements(
    "pim-conv-im2col-max-elements",
    llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"),
    llvm::cl::init(1ull << 20),
    llvm::cl::cat(OnnxMlirOptions));

// `--pim-conv-stream-chunk-positions`: upper bound on the Conv output
// positions materialized in one streamed chunk. Default: 1024.
llvm::cl::opt<uint64_t> pimConvStreamChunkPositions(
    "pim-conv-stream-chunk-positions",
    llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"),
    llvm::cl::init(1024),
    llvm::cl::cat(OnnxMlirOptions));

// `--pim-report-conv-lowering`: toggles the bounded Conv lowering report.
// Enabled by default.
llvm::cl::opt<bool> pimReportConvLowering(
    "pim-report-conv-lowering",
    llvm::cl::desc("Emit a bounded Conv lowering report"),
    llvm::cl::init(true),
    llvm::cl::cat(OnnxMlirOptions));
|
||||||
|
|
||||||
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||||
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||||
llvm::cl::init(false),
|
llvm::cl::init(false),
|
||||||
|
|||||||
@@ -30,19 +30,35 @@ typedef enum {
|
|||||||
PimMemoryReportFull = 2,
|
PimMemoryReportFull = 2,
|
||||||
} PimMemoryReportLevel;
|
} PimMemoryReportLevel;
|
||||||
|
|
||||||
|
// Conv lowering strategies selectable through the `pim-conv-lowering` option.
// The numeric values are spelled out explicitly, matching the other option
// enums in this header.
typedef enum {
  PimConvLoweringAuto = 0,               // "auto": pick the strategy automatically
  PimConvLoweringLegacy = 1,             // "legacy": explicit-im2col lowering
  PimConvLoweringDepthwise = 2,          // "depthwise": depthwise-specialized lowering
  PimConvLoweringPackedIm2Col = 3,       // "packed-im2col": explicit im2col, packed multi-position GEMM
  PimConvLoweringStreamedPatch = 4,      // "streamed-patch": streamed/chunked im2col rows, no packing
  PimConvLoweringStreamedPacked = 5,     // "streamed-packed": streamed/chunked im2col rows, packed GEMM
  PimConvLoweringOutputChannelTiled = 6, // "output-channel-tiled": relies on Gemm output-channel tiling
  PimConvLoweringInputKTiled = 7,        // "input-k-tiled": relies on Gemm K tiling
  PimConvLoweringTiled2D = 8,            // "tiled-2d": relies on Gemm 2D K/C tiling
} PimConvLoweringType;
|
||||||
|
|
||||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||||
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
|
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
|
||||||
|
extern llvm::cl::opt<PimConvLoweringType> pimConvLowering;
|
||||||
|
|
||||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||||
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
||||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||||
extern llvm::cl::opt<bool> pimEmitJson;
|
extern llvm::cl::opt<bool> pimEmitJson;
|
||||||
|
extern llvm::cl::opt<bool> pimReportConvLowering;
|
||||||
|
|
||||||
extern llvm::cl::opt<size_t> crossbarSize;
|
extern llvm::cl::opt<size_t> crossbarSize;
|
||||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||||
extern llvm::cl::opt<long> coresCount;
|
extern llvm::cl::opt<long> coresCount;
|
||||||
|
extern llvm::cl::opt<uint64_t> pimConvIm2colMaxElements;
|
||||||
|
extern llvm::cl::opt<uint64_t> pimConvStreamChunkPositions;
|
||||||
|
|
||||||
bool hasExplicitPimCoreCount();
|
bool hasExplicitPimCoreCount();
|
||||||
void verifyExplicitPimCoreCount();
|
void verifyExplicitPimCoreCount();
|
||||||
|
|||||||
@@ -19,9 +19,11 @@ using namespace mlir;
|
|||||||
|
|
||||||
namespace onnx_mlir {
|
namespace onnx_mlir {
|
||||||
|
|
||||||
bool isWeightLikeComputeOperand(Value value) {
|
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
|
||||||
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
if (!rankedType)
|
||||||
|
return false;
|
||||||
|
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||||
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
|
|||||||
while (auto* definingOp = value.getDefiningOp()) {
|
while (auto* definingOp = value.getDefiningOp()) {
|
||||||
if (!visited.insert(definingOp).second)
|
if (!visited.insert(definingOp).second)
|
||||||
return false;
|
return false;
|
||||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
|
||||||
|
if (!sourceType)
|
||||||
|
return false;
|
||||||
|
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
|
||||||
|
return false;
|
||||||
return true;
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||||
value = extractSliceOp.getSource();
|
value = extractSliceOp.getSource();
|
||||||
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Public entry point: runs the shared weight-materialization check with the
/// matrix-shape requirement enabled.
bool isWeightLikeComputeOperand(Value value) {
  constexpr bool kRequireMatrixShape = true;
  return isWeightMaterializationValue(value, kRequireMatrixShape);
}
|
||||||
|
|
||||||
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||||
if (auto mapped = mapper.lookupOrNull(value))
|
if (auto mapped = mapper.lookupOrNull(value))
|
||||||
return cast<Value>(mapped);
|
return cast<Value>(mapped);
|
||||||
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isWeightLikeComputeOperand(operand)) {
|
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
|
||||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||||
if (failed(clonedOperand))
|
if (failed(clonedOperand))
|
||||||
return failure();
|
return failure();
|
||||||
|
|||||||
@@ -26,6 +26,82 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns true when `op` is one of the external producer kinds that
/// materializeExternalTensorValue is willing to clone: Spatial channel
/// receive / extract-rows ops, or tensor slice / reshape (expand, collapse)
/// ops.
static bool isMaterializableExternalTensorOp(Operation* op) {
  const bool isSpatialProducer = isa<spatial::SpatChannelReceiveOp, spatial::SpatExtractRowsOp>(op);
  const bool isTensorViewOp = isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp>(op);
  return isSpatialProducer || isTensorViewOp;
}
|
||||||
|
|
||||||
|
//TODO REMOVE THIS UGLY FIX
|
||||||
|
//TODO: Remove this helper once compute_batch external tensor captures are
|
||||||
|
// fixed at the producer side.
|
||||||
|
//
|
||||||
|
// This function is a temporary SpatialToPim repair path. It clones selected
|
||||||
|
// external tensor producers, such as channel_receive and tensor view/slice ops,
|
||||||
|
// into the new pim.core_batch body when the old spat.compute_batch body refers
|
||||||
|
// to tensor values defined outside the batch.
|
||||||
|
//
|
||||||
|
// The real invariant should be stronger:
|
||||||
|
//
|
||||||
|
// A spat.compute_batch body must not capture external tensor values.
|
||||||
|
// Every tensor used inside the body must be either:
|
||||||
|
// - a compute_batch block argument,
|
||||||
|
// - defined inside the compute_batch body,
|
||||||
|
// - or a legal constant-like value.
|
||||||
|
//
|
||||||
|
// If this invariant is violated, the responsible producer, most likely merge
|
||||||
|
// schedule materialization, should emit verifier-clean Spatial IR instead of
|
||||||
|
// relying on SpatialToPim to clone external producer chains later.
|
||||||
|
//
|
||||||
|
// After that producer-side fix:
|
||||||
|
// 1. remove isMaterializableExternalTensorOp,
|
||||||
|
// 2. remove materializeExternalTensorValue,
|
||||||
|
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
|
||||||
|
// tensor operand,
|
||||||
|
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
|
||||||
|
// before SpatialToPim.
|
||||||
|
//
|
||||||
|
// Be careful not to replace every external tensor capture with a normal
|
||||||
|
// compute_batch input blindly: host-backed tensors and explicit inter-core
|
||||||
|
// communication have different semantics. In particular, channel_receive-like
|
||||||
|
// values should be materialized through the communication model, not silently
|
||||||
|
// treated as host inputs.
|
||||||
|
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
|
||||||
|
Location loc,
|
||||||
|
Block& oldBlock,
|
||||||
|
Value value,
|
||||||
|
IRMapping& mapper) {
|
||||||
|
if (mapper.contains(value))
|
||||||
|
return mapper.lookup(value);
|
||||||
|
|
||||||
|
if (!isa<TensorType>(value.getType()))
|
||||||
|
return value;
|
||||||
|
|
||||||
|
Operation* definingOp = value.getDefiningOp();
|
||||||
|
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (definingOp->getBlock() == &oldBlock)
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (!isMaterializableExternalTensorOp(definingOp))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
for (Value operand : definingOp->getOperands()) {
|
||||||
|
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
|
||||||
|
if (succeeded(materializedOperand))
|
||||||
|
mapper.map(operand, *materializedOperand);
|
||||||
|
}
|
||||||
|
|
||||||
|
Operation* cloned = rewriter.clone(*definingOp, mapper);
|
||||||
|
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||||
|
mapper.map(originalResult, clonedResult);
|
||||||
|
|
||||||
|
return mapper.lookup(value);
|
||||||
|
}
|
||||||
|
|
||||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||||
size_t& fallbackCoreId) {
|
size_t& fallbackCoreId) {
|
||||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||||
@@ -264,9 +340,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
|||||||
Operation* definingOp = operand.getDefiningOp();
|
Operation* definingOp = operand.getDefiningOp();
|
||||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||||
continue;
|
continue;
|
||||||
|
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||||
|
continue;
|
||||||
|
|
||||||
return computeBatchOp.emitOpError(
|
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
|
||||||
"expected external tensor communication to be materialized in Spatial before batch lowering");
|
continue;
|
||||||
|
|
||||||
|
InFlightDiagnostic diagnostic =
|
||||||
|
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||||
|
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
|
||||||
|
if (definingOp)
|
||||||
|
diagnostic << " from external producer '" << definingOp->getName() << "'";
|
||||||
|
return diagnostic;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operation* cloned = rewriter.clone(op, mapper);
|
Operation* cloned = rewriter.clone(op, mapper);
|
||||||
|
|||||||
@@ -2018,6 +2018,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
struct PendingProjectedTransferDescriptor {
|
struct PendingProjectedTransferDescriptor {
|
||||||
ProjectedBatchInputKey inputKey;
|
ProjectedBatchInputKey inputKey;
|
||||||
Operation* extractOp = nullptr;
|
Operation* extractOp = nullptr;
|
||||||
|
RankedTensorType sourceType;
|
||||||
RankedTensorType fragmentType;
|
RankedTensorType fragmentType;
|
||||||
SmallVector<int64_t, 4> fragmentShape;
|
SmallVector<int64_t, 4> fragmentShape;
|
||||||
SmallVector<SmallVector<SmallVector<int64_t, 4>, 16>, 8> fragmentOffsetsByLane;
|
SmallVector<SmallVector<SmallVector<int64_t, 4>, 16>, 8> fragmentOffsetsByLane;
|
||||||
@@ -2029,6 +2030,20 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
|
|
||||||
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
|
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
|
||||||
|
|
||||||
|
// A pending projected transfer is an "identity" when its fragment type equals
// a non-null source type and it consists of exactly one lane holding exactly
// one fragment whose offsets are all zero.
const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) {
  const bool coversWholeSource = descriptor.sourceType && descriptor.sourceType == descriptor.fragmentType;
  if (!coversWholeSource)
    return false;

  if (descriptor.fragmentOffsetsByLane.size() != 1)
    return false;

  const auto& laneFragments = descriptor.fragmentOffsetsByLane.front();
  if (laneFragments.size() != 1)
    return false;

  for (int64_t offset : laneFragments.front())
    if (offset != 0)
      return false;
  return true;
};
|
||||||
|
|
||||||
const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor,
|
const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor,
|
||||||
unsigned targetLane,
|
unsigned targetLane,
|
||||||
const AffineProjectedInputSliceMatch& match,
|
const AffineProjectedInputSliceMatch& match,
|
||||||
@@ -2117,6 +2132,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
if (descriptor.fragmentOffsetsByLane.empty()) {
|
if (descriptor.fragmentOffsetsByLane.empty()) {
|
||||||
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||||
descriptor.extractOp = match->extract.getOperation();
|
descriptor.extractOp = match->extract.getOperation();
|
||||||
|
descriptor.sourceType = match->sourceType;
|
||||||
descriptor.fragmentType = match->fragmentType;
|
descriptor.fragmentType = match->fragmentType;
|
||||||
descriptor.fragmentShape = match->fragmentShape;
|
descriptor.fragmentShape = match->fragmentShape;
|
||||||
descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1);
|
descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1);
|
||||||
@@ -2132,7 +2148,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
|
|
||||||
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||||
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation()
|
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation()
|
||||||
|| descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape
|
|| descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType
|
||||||
|
|| descriptor.fragmentShape != match->fragmentShape
|
||||||
|| descriptor.loopLowerBounds.size() != match->loops.size()) {
|
|| descriptor.loopLowerBounds.size() != match->loops.size()) {
|
||||||
descriptor.invalid = true;
|
descriptor.invalid = true;
|
||||||
continue;
|
continue;
|
||||||
@@ -2175,6 +2192,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
|||||||
continue;
|
continue;
|
||||||
if (pendingDescriptor.fragmentOffsetsByLane.empty())
|
if (pendingDescriptor.fragmentOffsetsByLane.empty())
|
||||||
continue;
|
continue;
|
||||||
|
if (isIdentityProjectedTransfer(pendingDescriptor))
|
||||||
|
continue;
|
||||||
|
|
||||||
MaterializedClass& targetClass = state.classes[targetClassId];
|
MaterializedClass& targetClass = state.classes[targetClassId];
|
||||||
ProjectedTransferDescriptor descriptor;
|
ProjectedTransferDescriptor descriptor;
|
||||||
@@ -2755,8 +2774,14 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
|
|||||||
if (*descriptor) {
|
if (*descriptor) {
|
||||||
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
|
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
|
||||||
|
|
||||||
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
|
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) {
|
||||||
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
|
if (!fanoutPlan.ordinaryMessages)
|
||||||
|
fanoutPlan.ordinaryMessages = MessageVector {};
|
||||||
|
fanoutPlan.ordinaryMessages->append(
|
||||||
|
receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds);
|
||||||
|
fanoutPlan.receivePlans.push_back(std::move(receivePlan));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
receivePlan.receiveType = projectedDescriptor.payloadType;
|
receivePlan.receiveType = projectedDescriptor.payloadType;
|
||||||
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
|
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ namespace {
|
|||||||
using namespace onnx_mlir::compact_asm;
|
using namespace onnx_mlir::compact_asm;
|
||||||
using SpatCompute = spatial::SpatCompute;
|
using SpatCompute = spatial::SpatCompute;
|
||||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||||
using spatial::getProducerValueRef;
|
|
||||||
|
|
||||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||||
@@ -187,13 +186,23 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// TODO: used for the report refactor.
// One report row for a compute whose body contains a large concat; printed in
// the "Collector concat materialization" section of the report.
struct CollectorConcatRow {
  uint64_t computeId {0};    // id assigned to the compute when it was collected
  int32_t coreId {-1};       // core of the compute; -1 until set
  uint64_t operandCount {0}; // largest concat input count found in the body
};
|
||||||
|
|
||||||
uint64_t totalComputeOps = 0;
|
uint64_t totalComputeOps = 0;
|
||||||
uint64_t totalLogicalComputes = 0;
|
uint64_t totalLogicalComputes = 0;
|
||||||
uint64_t totalBatchComputeOps = 0;
|
uint64_t totalBatchComputeOps = 0;
|
||||||
uint64_t totalInstructionCount = 0;
|
uint64_t totalInstructionCount = 0;
|
||||||
uint64_t totalCrossbarCount = 0;
|
uint64_t totalCrossbarCount = 0;
|
||||||
uint64_t nextBatchId = 0;
|
uint64_t nextBatchId = 0;
|
||||||
|
//TODO Used for report refactor
|
||||||
std::vector<ReportRow> collectedData;
|
std::vector<ReportRow> collectedData;
|
||||||
|
//TODO Used for report refactor
|
||||||
|
std::vector<CollectorConcatRow> collectorConcatRows;
|
||||||
|
|
||||||
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
||||||
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
||||||
@@ -206,7 +215,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
SmallVector<int32_t> coreIds;
|
SmallVector<int32_t> coreIds;
|
||||||
if (auto coreId = getComputeCoreId(spatCompute))
|
if (auto coreId = getComputeCoreId(spatCompute))
|
||||||
coreIds.push_back(*coreId);
|
coreIds.push_back(*coreId);
|
||||||
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
uint64_t computeId = totalComputeOps++;
|
||||||
|
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||||
|
uint64_t maxConcatOperands = 0;
|
||||||
|
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
|
||||||
|
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
|
||||||
|
});
|
||||||
|
//TODO 128 is a magic number
|
||||||
|
if (maxConcatOperands >= 128 && !coreIds.empty())
|
||||||
|
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
|
||||||
totalLogicalComputes += 1;
|
totalLogicalComputes += 1;
|
||||||
totalInstructionCount += numInst;
|
totalInstructionCount += numInst;
|
||||||
totalCrossbarCount += perInstanceCrossbarCount;
|
totalCrossbarCount += perInstanceCrossbarCount;
|
||||||
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
|||||||
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
||||||
};
|
};
|
||||||
printReportTotalsBlock(os, totalFields);
|
printReportTotalsBlock(os, totalFields);
|
||||||
if (!collectedData.empty())
|
if (!collectedData.empty() || !collectorConcatRows.empty())
|
||||||
os << "\n";
|
os << "\n";
|
||||||
|
|
||||||
|
if (!collectorConcatRows.empty()) {
|
||||||
|
os << "Collector concat materialization:\n";
|
||||||
|
for (const CollectorConcatRow& row : collectorConcatRows)
|
||||||
|
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
|
||||||
|
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
|
||||||
|
os << "\n";
|
||||||
|
}
|
||||||
|
|
||||||
sortReportEntriesByFirstCore(collectedData);
|
sortReportEntriesByFirstCore(collectedData);
|
||||||
|
|
||||||
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
|
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
|
||||||
|
|||||||
+10
-3
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
|
|||||||
llvm_unreachable("unknown merge scheduler kind");
|
llvm_unreachable("unknown merge scheduler kind");
|
||||||
}
|
}
|
||||||
|
|
||||||
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
|
void verifySchedule(const ComputeGraph& graph,
|
||||||
|
const MergeScheduleResult& result,
|
||||||
|
unsigned long crossbarCapacity,
|
||||||
|
size_t processorCount) {
|
||||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||||
|
|
||||||
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
|||||||
|
|
||||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
||||||
if (sourceCpu != targetCpu)
|
if (sourceCpu != targetCpu)
|
||||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
earliestTargetStart = addOrMax(
|
||||||
|
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
|
||||||
if (targetStart < earliestTargetStart) {
|
if (targetStart < earliestTargetStart) {
|
||||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||||
graph.nodes[edge.source].originalOrder,
|
graph.nodes[edge.source].originalOrder,
|
||||||
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
|||||||
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||||
entryOp->getContext()});
|
entryOp->getContext()});
|
||||||
}
|
}
|
||||||
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
|
verifySchedule(graph,
|
||||||
|
schedule,
|
||||||
|
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||||
|
options.processorCount);
|
||||||
return schedule;
|
return schedule;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "llvm/Support/ErrorHandling.h"
|
#include "llvm/Support/ErrorHandling.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@@ -21,6 +22,63 @@ struct ScheduledTask {
|
|||||||
Time endTime = 0;
|
Time endTime = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MeshModel {
|
||||||
|
size_t rows = 1;
|
||||||
|
size_t cols = 1;
|
||||||
|
long double averageDistance = 0.0L;
|
||||||
|
|
||||||
|
static MeshModel infer(size_t processorCount) {
|
||||||
|
MeshModel model;
|
||||||
|
if (processorCount == 0)
|
||||||
|
return model;
|
||||||
|
|
||||||
|
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
|
||||||
|
if (model.rows == 0)
|
||||||
|
model.rows = 1;
|
||||||
|
while (model.rows > 1 && processorCount % model.rows != 0)
|
||||||
|
--model.rows;
|
||||||
|
model.cols = (processorCount + model.rows - 1) / model.rows;
|
||||||
|
|
||||||
|
auto averageAxisDistance = [](size_t size) -> long double {
|
||||||
|
if (size <= 1)
|
||||||
|
return 0.0L;
|
||||||
|
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
|
||||||
|
};
|
||||||
|
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
|
||||||
|
return model;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<size_t, size_t> getCoord(size_t processor) const {
|
||||||
|
return {processor / cols, processor % cols};
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getDistance(size_t lhs, size_t rhs) const {
|
||||||
|
auto [lhsRow, lhsCol] = getCoord(lhs);
|
||||||
|
auto [rhsRow, rhsCol] = getCoord(rhs);
|
||||||
|
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
|
||||||
|
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
|
||||||
|
return rowDistance + colDistance;
|
||||||
|
}
|
||||||
|
|
||||||
|
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
|
||||||
|
if (sourceProcessor == targetProcessor || transferCost == 0)
|
||||||
|
return 0;
|
||||||
|
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
|
||||||
|
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
|
||||||
|
scale = std::max(0.25L, scale);
|
||||||
|
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t getCenterDistance(size_t processor) const {
|
||||||
|
auto [row, col] = getCoord(processor);
|
||||||
|
size_t centerRow = rows / 2;
|
||||||
|
size_t centerCol = cols / 2;
|
||||||
|
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
|
||||||
|
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
|
||||||
|
return rowDistance + colDistance;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||||
std::queue<size_t> readySinks;
|
std::queue<size_t> readySinks;
|
||||||
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Returns the time to move an edge of cost `transferCost` from
// `sourceProcessor` to `targetProcessor` on the mesh inferred from
// `processorCount`. The mesh is re-inferred on every call.
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
  // Same-processor and zero-cost transfers are free whatever the mesh looks
  // like, so answer them without paying for the mesh inference.
  if (sourceProcessor == targetProcessor || transferCost == 0)
    return 0;
  return MeshModel::infer(processorCount).scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
}
|
||||||
|
|
||||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||||
const size_t nodeCount = graph.nodes.size();
|
const size_t nodeCount = graph.nodes.size();
|
||||||
const size_t processorCount = options.processorCount;
|
const size_t processorCount = options.processorCount;
|
||||||
if (processorCount == 0)
|
if (processorCount == 0)
|
||||||
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||||
|
MeshModel mesh = MeshModel::infer(processorCount);
|
||||||
|
|
||||||
verifyOctTableSize(nodeCount, processorCount);
|
verifyOctTableSize(nodeCount, processorCount);
|
||||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||||
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||||
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
||||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
||||||
|
|
||||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||||
|
|
||||||
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time bestEft = 0;
|
Time bestEft = 0;
|
||||||
Time bestOeft = std::numeric_limits<Time>::max();
|
Time bestOeft = std::numeric_limits<Time>::max();
|
||||||
unsigned int bestOverlapCount = 0;
|
unsigned int bestOverlapCount = 0;
|
||||||
|
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
|
||||||
bool crossbarRejected = false;
|
bool crossbarRejected = false;
|
||||||
|
|
||||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||||
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
Time dataReady = 0;
|
Time dataReady = 0;
|
||||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||||
const ScheduledTask& predSchedule = schedules[pred];
|
const ScheduledTask& predSchedule = schedules[pred];
|
||||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
|
||||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
|
|
||||||
Time eft = addOrMax(est, computeCost);
|
Time eft = addOrMax(est, computeCost);
|
||||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||||
|
size_t centerDistance = mesh.getCenterDistance(processor);
|
||||||
|
|
||||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
||||||
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
|||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
bestOverlapCount = overlapCount;
|
bestOverlapCount = overlapCount;
|
||||||
|
bestCenterDistance = centerDistance;
|
||||||
}
|
}
|
||||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
|
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||||
|
&& centerDistance < bestCenterDistance) {
|
||||||
bestProcessor = processor;
|
bestProcessor = processor;
|
||||||
bestEst = est;
|
bestEst = est;
|
||||||
bestEft = eft;
|
bestEft = eft;
|
||||||
bestOeft = oeft;
|
bestOeft = oeft;
|
||||||
bestOverlapCount = overlapCount;
|
bestOverlapCount = overlapCount;
|
||||||
|
bestCenterDistance = centerDistance;
|
||||||
|
}
|
||||||
|
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||||
|
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
|
||||||
|
bestProcessor = processor;
|
||||||
|
bestEst = est;
|
||||||
|
bestEft = eft;
|
||||||
|
bestOeft = oeft;
|
||||||
|
bestOverlapCount = overlapCount;
|
||||||
|
bestCenterDistance = centerDistance;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
|
|||||||
mlir::MLIRContext* context = nullptr;
|
mlir::MLIRContext* context = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Transfer time for an edge of cost `transferCost` between two processors on
// the mesh inferred from `processorCount`. Returns 0 when both processors are
// the same or the cost is 0; otherwise the cost is scaled by the processors'
// mesh distance relative to the mesh's average distance (scale factor never
// below 0.25) and rounded up.
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
|
||||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
||||||
|
|
||||||
} // namespace spatial
|
} // namespace spatial
|
||||||
|
|||||||
Reference in New Issue
Block a user