Add register reuse + peft scheduler cost model + Useless merger
This commit is contained in:
@@ -17,6 +17,20 @@ std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRe
|
||||
|
||||
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
|
||||
|
||||
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return {};
|
||||
|
||||
std::string reportsDir = outputDir + "/reports";
|
||||
createDirectory(reportsDir);
|
||||
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out | std::ios::app);
|
||||
}
|
||||
|
||||
std::fstream openAppendedReportFile(const std::string& name) {
|
||||
return openAppendedReportFileWithExtension(name, "txt");
|
||||
}
|
||||
|
||||
std::string formatReportMemory(uint64_t bytes) {
|
||||
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||
int i = 0;
|
||||
|
||||
@@ -12,6 +12,8 @@ namespace onnx_mlir {
|
||||
|
||||
std::fstream openReportFile(const std::string& name);
|
||||
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||
std::fstream openAppendedReportFile(const std::string& name);
|
||||
std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||
std::string formatReportMemory(uint64_t bytes);
|
||||
|
||||
struct ReportField {
|
||||
|
||||
@@ -588,13 +588,37 @@ void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instructio
|
||||
++emittedInstructionCount;
|
||||
if (coreJsonStream)
|
||||
*coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ',';
|
||||
updateScalarRegisterCache(instruction);
|
||||
}
|
||||
|
||||
void PimCodeGen::updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const {
|
||||
switch (instruction.opcode) {
|
||||
case pim_binary::Opcode::sldi:
|
||||
scalarRegisterValues[instruction.rd] = instruction.r2OrImm;
|
||||
break;
|
||||
case pim_binary::Opcode::sld:
|
||||
case pim_binary::Opcode::sadd:
|
||||
case pim_binary::Opcode::ssub:
|
||||
case pim_binary::Opcode::smul:
|
||||
case pim_binary::Opcode::saddi:
|
||||
case pim_binary::Opcode::smuli:
|
||||
scalarRegisterValues[instruction.rd].reset();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const {
|
||||
auto registerIndex = pim::checkedU8OrCrash(registerNumber, "register number");
|
||||
auto immediateValue = pim::checkedI32OrCrash(immediate, "register immediate");
|
||||
if (scalarRegisterValues[registerIndex] == immediateValue)
|
||||
return;
|
||||
|
||||
pim_binary::InstructionRecord instruction;
|
||||
instruction.opcode = pim_binary::Opcode::sldi;
|
||||
instruction.rd = static_cast<uint8_t>(registerNumber);
|
||||
instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate");
|
||||
instruction.rd = registerIndex;
|
||||
instruction.r2OrImm = immediateValue;
|
||||
emitInstruction(instruction);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
@@ -170,6 +171,7 @@ class PimCodeGen {
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||
std::optional<unsigned> batchLane;
|
||||
mutable uint32_t emittedInstructionCount = 0;
|
||||
mutable std::array<std::optional<int32_t>, 256> scalarRegisterValues = {};
|
||||
|
||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getValueAddress(value, knowledge, batchLane);
|
||||
@@ -177,6 +179,7 @@ class PimCodeGen {
|
||||
size_t remapCoreId(size_t coreId) const;
|
||||
|
||||
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||
void updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const;
|
||||
|
||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||
|
||||
@@ -32,6 +32,31 @@ llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
|
||||
llvm::cl::init(PimMemoryReportNone),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimConvLoweringType> pimConvLowering(
|
||||
"pim-conv-lowering",
|
||||
llvm::cl::desc("Convolution lowering strategy for PIM"),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically")),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering")),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM")),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPatch,
|
||||
"streamed-patch",
|
||||
"Use streamed/chunked im2col rows without multi-position packing")),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringStreamedPacked,
|
||||
"streamed-packed",
|
||||
"Use streamed/chunked im2col rows with packed multi-position GEMM")),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringOutputChannelTiled,
|
||||
"output-channel-tiled",
|
||||
"Force Conv lowering that relies on Gemm output-channel tiling")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling")),
|
||||
llvm::cl::values(clEnumValN(PimConvLoweringTiled2D,
|
||||
"tiled-2d",
|
||||
"Force Conv lowering that relies on Gemm 2D K/C tiling")),
|
||||
llvm::cl::init(PimConvLoweringAuto),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||
@@ -49,6 +74,23 @@ llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<uint64_t> pimConvIm2colMaxElements(
|
||||
"pim-conv-im2col-max-elements",
|
||||
llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"),
|
||||
llvm::cl::init(1ull << 20),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<uint64_t> pimConvStreamChunkPositions(
|
||||
"pim-conv-stream-chunk-positions",
|
||||
llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"),
|
||||
llvm::cl::init(1024),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool> pimReportConvLowering("pim-report-conv-lowering",
|
||||
llvm::cl::desc("Emit a bounded Conv lowering report"),
|
||||
llvm::cl::init(true),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||
llvm::cl::init(false),
|
||||
|
||||
@@ -30,19 +30,35 @@ typedef enum {
|
||||
PimMemoryReportFull = 2,
|
||||
} PimMemoryReportLevel;
|
||||
|
||||
typedef enum {
|
||||
PimConvLoweringAuto = 0,
|
||||
PimConvLoweringLegacy = 1,
|
||||
PimConvLoweringDepthwise = 2,
|
||||
PimConvLoweringPackedIm2Col = 3,
|
||||
PimConvLoweringStreamedPatch = 4,
|
||||
PimConvLoweringStreamedPacked = 5,
|
||||
PimConvLoweringOutputChannelTiled = 6,
|
||||
PimConvLoweringInputKTiled = 7,
|
||||
PimConvLoweringTiled2D = 8,
|
||||
} PimConvLoweringType;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
|
||||
extern llvm::cl::opt<PimConvLoweringType> pimConvLowering;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
extern llvm::cl::opt<bool> pimEmitJson;
|
||||
extern llvm::cl::opt<bool> pimReportConvLowering;
|
||||
|
||||
extern llvm::cl::opt<size_t> crossbarSize;
|
||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<uint64_t> pimConvIm2colMaxElements;
|
||||
extern llvm::cl::opt<uint64_t> pimConvStreamChunkPositions;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
@@ -19,9 +19,11 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isWeightLikeComputeOperand(Value value) {
|
||||
static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) {
|
||||
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||
if (!rankedType)
|
||||
return false;
|
||||
if (requireMatrixShape && !isMatrixShape(rankedType.getShape()))
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
@@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp)) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!sourceType)
|
||||
return false;
|
||||
if (requireMatrixShape && !isMatrixShape(sourceType.getShape()))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
@@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); }
|
||||
|
||||
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||
if (auto mapped = mapper.lookupOrNull(value))
|
||||
return cast<Value>(mapped);
|
||||
@@ -91,7 +101,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWeightLikeComputeOperand(operand)) {
|
||||
if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) {
|
||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||
if (failed(clonedOperand))
|
||||
return failure();
|
||||
|
||||
@@ -26,6 +26,82 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
});
|
||||
}
|
||||
|
||||
static bool isMaterializableExternalTensorOp(Operation* op) {
|
||||
return isa<spatial::SpatChannelReceiveOp,
|
||||
spatial::SpatExtractRowsOp,
|
||||
tensor::ExtractSliceOp,
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CollapseShapeOp>(op);
|
||||
}
|
||||
|
||||
//TODO REMOVE THIS UGLY FIX
|
||||
//TODO: Remove this helper once compute_batch external tensor captures are
|
||||
// fixed at the producer side.
|
||||
//
|
||||
// This function is a temporary SpatialToPim repair path. It clones selected
|
||||
// external tensor producers, such as channel_receive and tensor view/slice ops,
|
||||
// into the new pim.core_batch body when the old spat.compute_batch body refers
|
||||
// to tensor values defined outside the batch.
|
||||
//
|
||||
// The real invariant should be stronger:
|
||||
//
|
||||
// A spat.compute_batch body must not capture external tensor values.
|
||||
// Every tensor used inside the body must be either:
|
||||
// - a compute_batch block argument,
|
||||
// - defined inside the compute_batch body,
|
||||
// - or a legal constant-like value.
|
||||
//
|
||||
// If this invariant is violated, the responsible producer, most likely merge
|
||||
// schedule materialization, should emit verifier-clean Spatial IR instead of
|
||||
// relying on SpatialToPim to clone external producer chains later.
|
||||
//
|
||||
// After that producer-side fix:
|
||||
// 1. remove isMaterializableExternalTensorOp,
|
||||
// 2. remove materializeExternalTensorValue,
|
||||
// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external
|
||||
// tensor operand,
|
||||
// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected
|
||||
// before SpatialToPim.
|
||||
//
|
||||
// Be careful not to replace every external tensor capture with a normal
|
||||
// compute_batch input blindly: host-backed tensors and explicit inter-core
|
||||
// communication have different semantics. In particular, channel_receive-like
|
||||
// values should be materialized through the communication model, not silently
|
||||
// treated as host inputs.
|
||||
static FailureOr<Value> materializeExternalTensorValue(IRRewriter& rewriter,
|
||||
Location loc,
|
||||
Block& oldBlock,
|
||||
Value value,
|
||||
IRMapping& mapper) {
|
||||
if (mapper.contains(value))
|
||||
return mapper.lookup(value);
|
||||
|
||||
if (!isa<TensorType>(value.getType()))
|
||||
return value;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
return failure();
|
||||
|
||||
if (definingOp->getBlock() == &oldBlock)
|
||||
return failure();
|
||||
|
||||
if (!isMaterializableExternalTensorOp(definingOp))
|
||||
return failure();
|
||||
|
||||
for (Value operand : definingOp->getOperands()) {
|
||||
FailureOr<Value> materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper);
|
||||
if (succeeded(materializedOperand))
|
||||
mapper.map(operand, *materializedOperand);
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(*definingOp, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
|
||||
return mapper.lookup(value);
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
@@ -264,9 +340,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||
continue;
|
||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
|
||||
return computeBatchOp.emitOpError(
|
||||
"expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||
if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper)))
|
||||
continue;
|
||||
|
||||
InFlightDiagnostic diagnostic =
|
||||
computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||
diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex;
|
||||
if (definingOp)
|
||||
diagnostic << " from external producer '" << definingOp->getName() << "'";
|
||||
return diagnostic;
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
|
||||
@@ -2018,6 +2018,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
struct PendingProjectedTransferDescriptor {
|
||||
ProjectedBatchInputKey inputKey;
|
||||
Operation* extractOp = nullptr;
|
||||
RankedTensorType sourceType;
|
||||
RankedTensorType fragmentType;
|
||||
SmallVector<int64_t, 4> fragmentShape;
|
||||
SmallVector<SmallVector<SmallVector<int64_t, 4>, 16>, 8> fragmentOffsetsByLane;
|
||||
@@ -2029,6 +2030,20 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
|
||||
DenseMap<ProducerKey, DenseMap<ClassId, PendingProjectedTransferDescriptor>, ProducerKeyInfo> pending;
|
||||
|
||||
const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) {
|
||||
if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType)
|
||||
return false;
|
||||
|
||||
if (descriptor.fragmentOffsetsByLane.size() != 1)
|
||||
return false;
|
||||
|
||||
ArrayRef<SmallVector<int64_t, 4>> fragments = descriptor.fragmentOffsetsByLane.front();
|
||||
if (fragments.size() != 1)
|
||||
return false;
|
||||
|
||||
return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; });
|
||||
};
|
||||
|
||||
const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor,
|
||||
unsigned targetLane,
|
||||
const AffineProjectedInputSliceMatch& match,
|
||||
@@ -2117,6 +2132,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
if (descriptor.fragmentOffsetsByLane.empty()) {
|
||||
descriptor.inputKey = {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
descriptor.extractOp = match->extract.getOperation();
|
||||
descriptor.sourceType = match->sourceType;
|
||||
descriptor.fragmentType = match->fragmentType;
|
||||
descriptor.fragmentShape = match->fragmentShape;
|
||||
descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1);
|
||||
@@ -2132,7 +2148,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
|
||||
ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast<unsigned>(inputIndex)};
|
||||
if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation()
|
||||
|| descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape
|
||||
|| descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType
|
||||
|| descriptor.fragmentShape != match->fragmentShape
|
||||
|| descriptor.loopLowerBounds.size() != match->loops.size()) {
|
||||
descriptor.invalid = true;
|
||||
continue;
|
||||
@@ -2175,6 +2192,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) {
|
||||
continue;
|
||||
if (pendingDescriptor.fragmentOffsetsByLane.empty())
|
||||
continue;
|
||||
if (isIdentityProjectedTransfer(pendingDescriptor))
|
||||
continue;
|
||||
|
||||
MaterializedClass& targetClass = state.classes[targetClassId];
|
||||
ProjectedTransferDescriptor descriptor;
|
||||
@@ -2755,8 +2774,14 @@ FailureOr<ScalarSourceFanoutPlan> buildScalarSourceFanoutPlan(MaterializerState&
|
||||
if (*descriptor) {
|
||||
const ProjectedTransferDescriptor& projectedDescriptor = **descriptor;
|
||||
|
||||
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType())
|
||||
return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type");
|
||||
if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) {
|
||||
if (!fanoutPlan.ordinaryMessages)
|
||||
fanoutPlan.ordinaryMessages = MessageVector {};
|
||||
fanoutPlan.ordinaryMessages->append(
|
||||
receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds);
|
||||
fanoutPlan.receivePlans.push_back(std::move(receivePlan));
|
||||
continue;
|
||||
}
|
||||
|
||||
receivePlan.receiveType = projectedDescriptor.payloadType;
|
||||
receivePlan.projectedExtractOp = projectedDescriptor.extractOp;
|
||||
|
||||
@@ -42,7 +42,6 @@ namespace {
|
||||
using namespace onnx_mlir::compact_asm;
|
||||
using SpatCompute = spatial::SpatCompute;
|
||||
using SpatComputeBatch = spatial::SpatComputeBatch;
|
||||
using spatial::getProducerValueRef;
|
||||
|
||||
static std::optional<int32_t> getComputeCoreId(SpatCompute compute) {
|
||||
if (auto coreIdAttr = compute->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName)) {
|
||||
@@ -187,13 +186,23 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
SmallVector<int32_t> coreIds;
|
||||
};
|
||||
|
||||
//TODO Used for report refactor
|
||||
struct CollectorConcatRow {
|
||||
uint64_t computeId = 0;
|
||||
int32_t coreId = -1;
|
||||
uint64_t operandCount = 0;
|
||||
};
|
||||
|
||||
uint64_t totalComputeOps = 0;
|
||||
uint64_t totalLogicalComputes = 0;
|
||||
uint64_t totalBatchComputeOps = 0;
|
||||
uint64_t totalInstructionCount = 0;
|
||||
uint64_t totalCrossbarCount = 0;
|
||||
uint64_t nextBatchId = 0;
|
||||
//TODO Used for report refactor
|
||||
std::vector<ReportRow> collectedData;
|
||||
//TODO Used for report refactor
|
||||
std::vector<CollectorConcatRow> collectorConcatRows;
|
||||
|
||||
auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t {
|
||||
return static_cast<uint64_t>(spatial::collectDistinctCrossbarWeights(op).size());
|
||||
@@ -206,7 +215,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
SmallVector<int32_t> coreIds;
|
||||
if (auto coreId = getComputeCoreId(spatCompute))
|
||||
coreIds.push_back(*coreId);
|
||||
collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
uint64_t computeId = totalComputeOps++;
|
||||
collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds});
|
||||
uint64_t maxConcatOperands = 0;
|
||||
spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) {
|
||||
maxConcatOperands = std::max<uint64_t>(maxConcatOperands, concatOp.getInputs().size());
|
||||
});
|
||||
//TODO 128 is a magic number
|
||||
if (maxConcatOperands >= 128 && !coreIds.empty())
|
||||
collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands});
|
||||
totalLogicalComputes += 1;
|
||||
totalInstructionCount += numInst;
|
||||
totalCrossbarCount += perInstanceCrossbarCount;
|
||||
@@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
|
||||
{"Number of used crossbars", std::to_string(totalCrossbarCount) }
|
||||
};
|
||||
printReportTotalsBlock(os, totalFields);
|
||||
if (!collectedData.empty())
|
||||
if (!collectedData.empty() || !collectorConcatRows.empty())
|
||||
os << "\n";
|
||||
|
||||
if (!collectorConcatRows.empty()) {
|
||||
os << "Collector concat materialization:\n";
|
||||
for (const CollectorConcatRow& row : collectorConcatRows)
|
||||
os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId
|
||||
<< ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n";
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
sortReportEntriesByFirstCore(collectedData);
|
||||
|
||||
for (uint64_t cI = 0; cI < totalComputeOps; ++cI) {
|
||||
|
||||
+10
-3
@@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() {
|
||||
llvm_unreachable("unknown merge scheduler kind");
|
||||
}
|
||||
|
||||
void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) {
|
||||
void verifySchedule(const ComputeGraph& graph,
|
||||
const MergeScheduleResult& result,
|
||||
unsigned long crossbarCapacity,
|
||||
size_t processorCount) {
|
||||
llvm::DenseMap<size_t, std::vector<std::pair<size_t, size_t>>> tasksByCpu;
|
||||
tasksByCpu.reserve(result.cpuToLastComputeMap.size());
|
||||
|
||||
@@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result
|
||||
|
||||
Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost);
|
||||
if (sourceCpu != targetCpu)
|
||||
earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost);
|
||||
earliestTargetStart = addOrMax(
|
||||
earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount));
|
||||
if (targetStart < earliestTargetStart) {
|
||||
std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}",
|
||||
graph.nodes[edge.source].originalOrder,
|
||||
@@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() {
|
||||
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||
entryOp->getContext()});
|
||||
}
|
||||
verifySchedule(graph, schedule, static_cast<unsigned long>(crossbarCountInCore.getValue()));
|
||||
verifySchedule(graph,
|
||||
schedule,
|
||||
static_cast<unsigned long>(crossbarCountInCore.getValue()),
|
||||
options.processorCount);
|
||||
return schedule;
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
@@ -21,6 +22,63 @@ struct ScheduledTask {
|
||||
Time endTime = 0;
|
||||
};
|
||||
|
||||
struct MeshModel {
|
||||
size_t rows = 1;
|
||||
size_t cols = 1;
|
||||
long double averageDistance = 0.0L;
|
||||
|
||||
static MeshModel infer(size_t processorCount) {
|
||||
MeshModel model;
|
||||
if (processorCount == 0)
|
||||
return model;
|
||||
|
||||
model.rows = static_cast<size_t>(std::sqrt(static_cast<long double>(processorCount)));
|
||||
if (model.rows == 0)
|
||||
model.rows = 1;
|
||||
while (model.rows > 1 && processorCount % model.rows != 0)
|
||||
--model.rows;
|
||||
model.cols = (processorCount + model.rows - 1) / model.rows;
|
||||
|
||||
auto averageAxisDistance = [](size_t size) -> long double {
|
||||
if (size <= 1)
|
||||
return 0.0L;
|
||||
return static_cast<long double>(size * size - 1) / (3.0L * static_cast<long double>(size));
|
||||
};
|
||||
model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols);
|
||||
return model;
|
||||
}
|
||||
|
||||
std::pair<size_t, size_t> getCoord(size_t processor) const {
|
||||
return {processor / cols, processor % cols};
|
||||
}
|
||||
|
||||
size_t getDistance(size_t lhs, size_t rhs) const {
|
||||
auto [lhsRow, lhsCol] = getCoord(lhs);
|
||||
auto [rhsRow, rhsCol] = getCoord(rhs);
|
||||
size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow;
|
||||
size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol;
|
||||
return rowDistance + colDistance;
|
||||
}
|
||||
|
||||
Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const {
|
||||
if (sourceProcessor == targetProcessor || transferCost == 0)
|
||||
return 0;
|
||||
long double distance = static_cast<long double>(getDistance(sourceProcessor, targetProcessor));
|
||||
long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L;
|
||||
scale = std::max(0.25L, scale);
|
||||
return static_cast<Time>(std::ceil(static_cast<long double>(transferCost) * scale));
|
||||
}
|
||||
|
||||
size_t getCenterDistance(size_t processor) const {
|
||||
auto [row, col] = getCoord(processor);
|
||||
size_t centerRow = rows / 2;
|
||||
size_t centerCol = cols / 2;
|
||||
size_t rowDistance = row > centerRow ? row - centerRow : centerRow - row;
|
||||
size_t colDistance = col > centerCol ? col - centerCol : centerCol - col;
|
||||
return rowDistance + colDistance;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::vector<size_t>> buildReverseLevels(const ComputeGraph& graph) {
|
||||
std::vector<size_t> remainingSuccessors(graph.nodes.size(), 0);
|
||||
std::queue<size_t> readySinks;
|
||||
@@ -77,11 +135,16 @@ void verifyOctTableSize(size_t nodeCount, size_t processorCount) {
|
||||
|
||||
} // namespace
|
||||
|
||||
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount) {
|
||||
return MeshModel::infer(processorCount).scaleTransferCost(transferCost, sourceProcessor, targetProcessor);
|
||||
}
|
||||
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options) {
|
||||
const size_t nodeCount = graph.nodes.size();
|
||||
const size_t processorCount = options.processorCount;
|
||||
if (processorCount == 0)
|
||||
llvm::report_fatal_error("PEFT scheduler: processor count must be positive");
|
||||
MeshModel mesh = MeshModel::infer(processorCount);
|
||||
|
||||
verifyOctTableSize(nodeCount, processorCount);
|
||||
std::vector<std::vector<size_t>> reverseLevels = buildReverseLevels(graph);
|
||||
@@ -89,7 +152,6 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
// MOCK: Replace this with your actual heterogeneous cost lookup.
|
||||
// If graph.nodes[task] is modified to hold a vector of costs per processor, access it here.
|
||||
auto getComputeCost = [&](size_t task, size_t processor) -> Time { return graph.nodes[task].cost; };
|
||||
|
||||
std::vector<Time> oct(nodeCount * processorCount, 0);
|
||||
std::vector<Time> minOctPlusComp(nodeCount, 0);
|
||||
|
||||
@@ -177,6 +239,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time bestEft = 0;
|
||||
Time bestOeft = std::numeric_limits<Time>::max();
|
||||
unsigned int bestOverlapCount = 0;
|
||||
size_t bestCenterDistance = std::numeric_limits<size_t>::max();
|
||||
bool crossbarRejected = false;
|
||||
|
||||
for (size_t processor = 0; processor < processorCount; ++processor) {
|
||||
@@ -191,7 +254,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
Time dataReady = 0;
|
||||
for (const auto& [pred, comm] : graph.predecessors[task]) {
|
||||
const ScheduledTask& predSchedule = schedules[pred];
|
||||
Time commPenalty = predSchedule.processor == processor ? 0 : comm;
|
||||
Time commPenalty = getPeftTransferTime(comm, predSchedule.processor, processor, processorCount);
|
||||
dataReady = std::max(dataReady, addOrMax(predSchedule.endTime, commPenalty));
|
||||
}
|
||||
|
||||
@@ -218,6 +281,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
|
||||
Time eft = addOrMax(est, computeCost);
|
||||
Time oeft = addOrMax(eft, oct[task * processorCount + processor]);
|
||||
size_t centerDistance = mesh.getCenterDistance(processor);
|
||||
|
||||
if (oeft < bestOeft || (oeft == bestOeft && eft < bestEft)
|
||||
|| (oeft == bestOeft && eft == bestEft && est < bestEst)) {
|
||||
@@ -226,13 +290,25 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapCount = overlapCount;
|
||||
bestCenterDistance = centerDistance;
|
||||
}
|
||||
else if (oeft == bestOeft && eft == bestEft && est < bestEst && overlapCount < bestOverlapCount) {
|
||||
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||
&& centerDistance < bestCenterDistance) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapCount = overlapCount;
|
||||
bestCenterDistance = centerDistance;
|
||||
}
|
||||
else if (oeft == bestOeft && eft == bestEft && est == bestEst
|
||||
&& centerDistance == bestCenterDistance && overlapCount < bestOverlapCount) {
|
||||
bestProcessor = processor;
|
||||
bestEst = est;
|
||||
bestEft = eft;
|
||||
bestOeft = oeft;
|
||||
bestOverlapCount = overlapCount;
|
||||
bestCenterDistance = centerDistance;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ struct PeftScheduleOptions {
|
||||
mlir::MLIRContext* context = nullptr;
|
||||
};
|
||||
|
||||
Time getPeftTransferTime(Time transferCost, size_t sourceProcessor, size_t targetProcessor, size_t processorCount);
|
||||
MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftScheduleOptions& options);
|
||||
|
||||
} // namespace spatial
|
||||
|
||||
Reference in New Issue
Block a user