diff --git a/src/PIM/Common/Support/ReportUtils.cpp b/src/PIM/Common/Support/ReportUtils.cpp index 4f1e918..350faa6 100644 --- a/src/PIM/Common/Support/ReportUtils.cpp +++ b/src/PIM/Common/Support/ReportUtils.cpp @@ -17,6 +17,20 @@ std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRe std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); } +std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension) { + std::string outputDir = getOutputDir(); + if (outputDir.empty()) + return {}; + + std::string reportsDir = outputDir + "/reports"; + createDirectory(reportsDir); + return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out | std::ios::app); +} + +std::fstream openAppendedReportFile(const std::string& name) { + return openAppendedReportFileWithExtension(name, "txt"); +} + std::string formatReportMemory(uint64_t bytes) { const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"}; int i = 0; diff --git a/src/PIM/Common/Support/ReportUtils.hpp b/src/PIM/Common/Support/ReportUtils.hpp index d722fe7..0cfc38b 100644 --- a/src/PIM/Common/Support/ReportUtils.hpp +++ b/src/PIM/Common/Support/ReportUtils.hpp @@ -12,6 +12,8 @@ namespace onnx_mlir { std::fstream openReportFile(const std::string& name); std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension); +std::fstream openAppendedReportFile(const std::string& name); +std::fstream openAppendedReportFileWithExtension(const std::string& name, llvm::StringRef extension); std::string formatReportMemory(uint64_t bytes); struct ReportField { diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 669ff57..30b0ac4 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -588,13 +588,37 @@ void PimCodeGen::emitInstruction(const pim_binary::InstructionRecord& instructio ++emittedInstructionCount; if (coreJsonStream) *coreJsonStream << json::Value(pim_binary::makeInstructionJson(instruction)) << ','; + updateScalarRegisterCache(instruction); +} + +void PimCodeGen::updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const { + switch (instruction.opcode) { + case pim_binary::Opcode::sldi: + scalarRegisterValues[instruction.rd] = instruction.r2OrImm; + break; + case pim_binary::Opcode::sld: + case pim_binary::Opcode::sadd: + case pim_binary::Opcode::ssub: + case pim_binary::Opcode::smul: + case pim_binary::Opcode::saddi: + case pim_binary::Opcode::smuli: + scalarRegisterValues[instruction.rd].reset(); + break; + default: + break; + } } void PimCodeGen::genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const { + auto registerIndex = pim::checkedU8OrCrash(registerNumber, "register number"); + auto immediateValue = pim::checkedI32OrCrash(immediate, "register immediate"); + if (scalarRegisterValues[registerIndex] == immediateValue) + return; + pim_binary::InstructionRecord instruction; instruction.opcode = pim_binary::Opcode::sldi; - instruction.rd = static_cast(registerNumber); - instruction.r2OrImm = pim::checkedI32OrCrash(immediate, "register immediate"); + instruction.rd = registerIndex; + instruction.r2OrImm = immediateValue; emitInstruction(instruction); } diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index 59a043e..078d42a 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -9,6 +9,7 @@ #include "llvm/Support/JSON.h" #include "llvm/Support/raw_os_ostream.h" +#include #include #include #include @@ -170,6 +171,7 @@ class PimCodeGen { const llvm::DenseMap& emittedCoreIds; std::optional batchLane; mutable uint32_t emittedInstructionCount = 0; + mutable std::array, 256> scalarRegisterValues = {}; size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { return memory.getValueAddress(value, knowledge, batchLane); @@ -177,6 +179,7 @@ class PimCodeGen { size_t remapCoreId(size_t coreId) const; void emitInstruction(const pim_binary::InstructionRecord& instruction) const; + void updateScalarRegisterCache(const pim_binary::InstructionRecord& instruction) const; void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const; void setupRd(size_t rdAddress, size_t rdOffset) const; diff --git a/src/PIM/Compiler/PimCompilerOptions.cpp b/src/PIM/Compiler/PimCompilerOptions.cpp index 7d1cb01..f2a4d60 100644 --- a/src/PIM/Compiler/PimCompilerOptions.cpp +++ b/src/PIM/Compiler/PimCompilerOptions.cpp @@ -32,6 +32,31 @@ llvm::cl::opt pimMemoryReport( llvm::cl::init(PimMemoryReportNone), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt pimConvLowering( + "pim-conv-lowering", + llvm::cl::desc("Convolution lowering strategy for PIM"), + llvm::cl::values(clEnumValN(PimConvLoweringAuto, "auto", "Select the Conv lowering strategy automatically")), + llvm::cl::values(clEnumValN(PimConvLoweringLegacy, "legacy", "Use the legacy explicit-im2col Conv lowering")), + llvm::cl::values(clEnumValN(PimConvLoweringDepthwise, "depthwise", "Force the depthwise-specialized Conv lowering")), + llvm::cl::values( + clEnumValN(PimConvLoweringPackedIm2Col, "packed-im2col", "Use explicit im2col with packed multi-position GEMM")), + llvm::cl::values(clEnumValN(PimConvLoweringStreamedPatch, + "streamed-patch", + "Use streamed/chunked im2col rows without multi-position packing")), + llvm::cl::values(clEnumValN(PimConvLoweringStreamedPacked, + "streamed-packed", + "Use streamed/chunked im2col rows with packed multi-position GEMM")), + llvm::cl::values(clEnumValN(PimConvLoweringOutputChannelTiled, + "output-channel-tiled", + "Force Conv lowering that relies on Gemm output-channel tiling")), + llvm::cl::values( + clEnumValN(PimConvLoweringInputKTiled, "input-k-tiled", "Force Conv lowering that relies on Gemm K tiling")), + llvm::cl::values(clEnumValN(PimConvLoweringTiled2D, + "tiled-2d", + "Force Conv lowering that relies on Gemm 2D K/C tiling")), + llvm::cl::init(PimConvLoweringAuto), + llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt pimOnlyCodegen("pim-only-codegen", llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"), @@ -49,6 +74,23 @@ llvm::cl::opt useExperimentalConvImpl("use-experimental-conv-impl", llvm::cl::init(false), llvm::cl::cat(OnnxMlirOptions)); +llvm::cl::opt pimConvIm2colMaxElements( + "pim-conv-im2col-max-elements", + llvm::cl::desc("Maximum number of im2col elements to materialize globally for one Conv before streaming/chunking"), + llvm::cl::init(1ull << 20), + llvm::cl::cat(OnnxMlirOptions)); + +llvm::cl::opt pimConvStreamChunkPositions( + "pim-conv-stream-chunk-positions", + llvm::cl::desc("Maximum number of Conv output positions to materialize in one streamed chunk"), + llvm::cl::init(1024), + llvm::cl::cat(OnnxMlirOptions)); + +llvm::cl::opt pimReportConvLowering("pim-report-conv-lowering", + llvm::cl::desc("Emit a bounded Conv lowering report"), + llvm::cl::init(true), + llvm::cl::cat(OnnxMlirOptions)); + llvm::cl::opt pimEmitJson("pim-emit-json", llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"), llvm::cl::init(false), diff --git a/src/PIM/Compiler/PimCompilerOptions.hpp b/src/PIM/Compiler/PimCompilerOptions.hpp index 3d90409..5fc77fb 100644 --- a/src/PIM/Compiler/PimCompilerOptions.hpp +++ b/src/PIM/Compiler/PimCompilerOptions.hpp @@ -30,19 +30,35 @@ typedef enum { PimMemoryReportFull = 2, } PimMemoryReportLevel; +typedef enum { + PimConvLoweringAuto = 0, + PimConvLoweringLegacy = 1, + PimConvLoweringDepthwise = 2, + PimConvLoweringPackedIm2Col = 3, + PimConvLoweringStreamedPatch = 4, + PimConvLoweringStreamedPacked = 5, + PimConvLoweringOutputChannelTiled = 6, + PimConvLoweringInputKTiled = 7, + PimConvLoweringTiled2D = 8, +} PimConvLoweringType; + extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::opt pimEmissionTarget; extern llvm::cl::opt pimMergeScheduler; extern llvm::cl::opt pimMemoryReport; +extern llvm::cl::opt pimConvLowering; extern llvm::cl::opt pimOnlyCodegen; extern llvm::cl::opt pimDisableMemoryCoalescing; extern llvm::cl::opt useExperimentalConvImpl; extern llvm::cl::opt pimEmitJson; +extern llvm::cl::opt pimReportConvLowering; extern llvm::cl::opt crossbarSize; extern llvm::cl::opt crossbarCountInCore; extern llvm::cl::opt coresCount; +extern llvm::cl::opt pimConvIm2colMaxElements; +extern llvm::cl::opt pimConvStreamChunkPositions; bool hasExplicitPimCoreCount(); void verifyExplicitPimCoreCount(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp index 27e2f93..3d82ff8 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.cpp @@ -19,9 +19,11 @@ using namespace mlir; namespace onnx_mlir { -bool isWeightLikeComputeOperand(Value value) { +static bool isWeightMaterializationValue(Value value, bool requireMatrixShape) { auto rankedType = dyn_cast(value.getType()); - if (!rankedType || !isMatrixShape(rankedType.getShape())) + if (!rankedType) + return false; + if (requireMatrixShape && !isMatrixShape(rankedType.getShape())) return false; llvm::SmallPtrSet visited; @@ -29,8 +31,14 @@ bool isWeightLikeComputeOperand(Value value) { while (auto* definingOp = value.getDefiningOp()) { if (!visited.insert(definingOp).second) return false; - if (isa(definingOp) || hasWeightAlways(definingOp)) + if (isa(definingOp) || hasWeightAlways(definingOp)) { + auto sourceType = dyn_cast(value.getType()); + if (!sourceType) + return false; + if (requireMatrixShape && !isMatrixShape(sourceType.getShape())) + return false; return true; + } if (auto extractSliceOp = dyn_cast(definingOp)) { value = extractSliceOp.getSource(); @@ -55,6 +63,8 @@ bool isWeightLikeComputeOperand(Value value) { return false; } +bool isWeightLikeComputeOperand(Value value) { return isWeightMaterializationValue(value, /*requireMatrixShape=*/true); } + FailureOr materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) { if (auto mapped = mapper.lookupOrNull(value)) return cast(mapped); @@ -91,7 +101,7 @@ FailureOr materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr continue; } - if (isWeightLikeComputeOperand(operand)) { + if (isWeightMaterializationValue(operand, /*requireMatrixShape=*/false)) { auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper); if (failed(clonedOperand)) return failure(); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 9888cf4..daa3943 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -26,6 +26,82 @@ static bool isUsedOnlyAsExplicitHostOperand(Value value) { }); } +static bool isMaterializableExternalTensorOp(Operation* op) { + return isa(op); +} + +//TODO REMOVE THIS UGLY FIX +//TODO: Remove this helper once compute_batch external tensor captures are +// fixed at the producer side. +// +// This function is a temporary SpatialToPim repair path. It clones selected +// external tensor producers, such as channel_receive and tensor view/slice ops, +// into the new pim.core_batch body when the old spat.compute_batch body refers +// to tensor values defined outside the batch. +// +// The real invariant should be stronger: +// +// A spat.compute_batch body must not capture external tensor values. +// Every tensor used inside the body must be either: +// - a compute_batch block argument, +// - defined inside the compute_batch body, +// - or a legal constant-like value. +// +// If this invariant is violated, the responsible producer, most likely merge +// schedule materialization, should emit verifier-clean Spatial IR instead of +// relying on SpatialToPim to clone external producer chains later. +// +// After that producer-side fix: +// 1. remove isMaterializableExternalTensorOp, +// 2. remove materializeExternalTensorValue, +// 3. make lowerComputeBatchOp emit a hard diagnostic for any unmapped external +// tensor operand, +// 4. keep/strengthen the Spatial verifier so the invalid capture is rejected +// before SpatialToPim. +// +// Be careful not to replace every external tensor capture with a normal +// compute_batch input blindly: host-backed tensors and explicit inter-core +// communication have different semantics. In particular, channel_receive-like +// values should be materialized through the communication model, not silently +// treated as host inputs. +static FailureOr materializeExternalTensorValue(IRRewriter& rewriter, + Location loc, + Block& oldBlock, + Value value, + IRMapping& mapper) { + if (mapper.contains(value)) + return mapper.lookup(value); + + if (!isa(value.getType())) + return value; + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || definingOp->hasTrait()) + return failure(); + + if (definingOp->getBlock() == &oldBlock) + return failure(); + + if (!isMaterializableExternalTensorOp(definingOp)) + return failure(); + + for (Value operand : definingOp->getOperands()) { + FailureOr materializedOperand = materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper); + if (succeeded(materializedOperand)) + mapper.map(operand, *materializedOperand); + } + + Operation* cloned = rewriter.clone(*definingOp, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + + return mapper.lookup(value); +} + static FailureOr> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) @@ -264,9 +340,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute Operation* definingOp = operand.getDefiningOp(); if (definingOp && definingOp->getBlock() == &oldBlock) continue; + if (definingOp && definingOp->hasTrait()) + continue; - return computeBatchOp.emitOpError( - "expected external tensor communication to be materialized in Spatial before batch lowering"); + if (succeeded(materializeExternalTensorValue(rewriter, loc, oldBlock, operand, mapper))) + continue; + + InFlightDiagnostic diagnostic = + computeBatchOp.emitOpError("expected external tensor communication to be materialized in Spatial before batch lowering"); + diagnostic << " while cloning nested op '" << op.getName() << "' tensor operand #" << operandIndex; + if (definingOp) + diagnostic << " from external producer '" << definingOp->getName() << "'"; + return diagnostic; } Operation* cloned = rewriter.clone(op, mapper); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 5100c2b..1510c1d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -2018,6 +2018,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { struct PendingProjectedTransferDescriptor { ProjectedBatchInputKey inputKey; Operation* extractOp = nullptr; + RankedTensorType sourceType; RankedTensorType fragmentType; SmallVector fragmentShape; SmallVector, 16>, 8> fragmentOffsetsByLane; @@ -2029,6 +2030,20 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { DenseMap, ProducerKeyInfo> pending; + const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) { + if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType) + return false; + + if (descriptor.fragmentOffsetsByLane.size() != 1) + return false; + + ArrayRef> fragments = descriptor.fragmentOffsetsByLane.front(); + if (fragments.size() != 1) + return false; + + return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; }); + }; + const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor, unsigned targetLane, const AffineProjectedInputSliceMatch& match, @@ -2117,6 +2132,7 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { if (descriptor.fragmentOffsetsByLane.empty()) { descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; descriptor.extractOp = match->extract.getOperation(); + descriptor.sourceType = match->sourceType; descriptor.fragmentType = match->fragmentType; descriptor.fragmentShape = match->fragmentShape; descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1); @@ -2132,7 +2148,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation() - || descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape + || descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType + || descriptor.fragmentShape != match->fragmentShape || descriptor.loopLowerBounds.size() != match->loops.size()) { descriptor.invalid = true; continue; @@ -2175,6 +2192,8 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { continue; if (pendingDescriptor.fragmentOffsetsByLane.empty()) continue; + if (isIdentityProjectedTransfer(pendingDescriptor)) + continue; MaterializedClass& targetClass = state.classes[targetClassId]; ProjectedTransferDescriptor descriptor; @@ -2755,8 +2774,14 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& if (*descriptor) { const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; - if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) - return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type"); + if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) { + if (!fanoutPlan.ordinaryMessages) + fanoutPlan.ordinaryMessages = MessageVector {}; + fanoutPlan.ordinaryMessages->append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + fanoutPlan.receivePlans.push_back(std::move(receivePlan)); + continue; + } receivePlan.receiveType = projectedDescriptor.payloadType; receivePlan.projectedExtractOp = projectedDescriptor.extractOp; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index d5015cc..1747a6d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -42,7 +42,6 @@ namespace { using namespace onnx_mlir::compact_asm; using SpatCompute = spatial::SpatCompute; using SpatComputeBatch = spatial::SpatComputeBatch; -using spatial::getProducerValueRef; static std::optional getComputeCoreId(SpatCompute compute) { if (auto coreIdAttr = compute->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { @@ -187,13 +186,23 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu SmallVector coreIds; }; + //TODO Used for report refactor + struct CollectorConcatRow { + uint64_t computeId = 0; + int32_t coreId = -1; + uint64_t operandCount = 0; + }; + uint64_t totalComputeOps = 0; uint64_t totalLogicalComputes = 0; uint64_t totalBatchComputeOps = 0; uint64_t totalInstructionCount = 0; uint64_t totalCrossbarCount = 0; uint64_t nextBatchId = 0; + //TODO Used for report refactor std::vector collectedData; + //TODO Used for report refactor + std::vector collectorConcatRows; auto getPerInstanceCrossbarCount = [&](Operation* op) -> uint64_t { return static_cast(spatial::collectDistinctCrossbarWeights(op).size()); @@ -206,7 +215,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu SmallVector coreIds; if (auto coreId = getComputeCoreId(spatCompute)) coreIds.push_back(*coreId); - collectedData.push_back({totalComputeOps++, 1, perInstanceCrossbarCount, numInst, false, coreIds}); + uint64_t computeId = totalComputeOps++; + collectedData.push_back({computeId, 1, perInstanceCrossbarCount, numInst, false, coreIds}); + uint64_t maxConcatOperands = 0; + spatCompute.getBody().walk([&](spatial::SpatConcatOp concatOp) { + maxConcatOperands = std::max(maxConcatOperands, concatOp.getInputs().size()); + }); + //TODO 128 is a magic number + if (maxConcatOperands >= 128 && !coreIds.empty()) + collectorConcatRows.push_back({computeId, coreIds.front(), maxConcatOperands}); totalLogicalComputes += 1; totalInstructionCount += numInst; totalCrossbarCount += perInstanceCrossbarCount; @@ -238,9 +255,17 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu {"Number of used crossbars", std::to_string(totalCrossbarCount) } }; printReportTotalsBlock(os, totalFields); - if (!collectedData.empty()) + if (!collectedData.empty() || !collectorConcatRows.empty()) os << "\n"; + if (!collectorConcatRows.empty()) { + os << "Collector concat materialization:\n"; + for (const CollectorConcatRow& row : collectorConcatRows) + os << "\tmaterialization_kind = single_collector_concat, compute = " << row.computeId + << ", concat_operand_count = " << row.operandCount << ", collector_core = " << row.coreId << "\n"; + os << "\n"; + } + sortReportEntriesByFirstCore(collectedData); for (uint64_t cI = 0; cI < totalComputeOps; ++cI) { diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp index 4fcef5d..3276323 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/MergeSchedulingAnalysis.cpp @@ -23,7 +23,10 @@ MergeSchedulerKind getSchedulerKind() { llvm_unreachable("unknown merge scheduler kind"); } -void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result, unsigned long crossbarCapacity) { +void verifySchedule(const ComputeGraph& graph, + const MergeScheduleResult& result, + unsigned long crossbarCapacity, + size_t processorCount) { llvm::DenseMap>> tasksByCpu; tasksByCpu.reserve(result.cpuToLastComputeMap.size()); @@ -79,7 +82,8 @@ void verifySchedule(const ComputeGraph& graph, const MergeScheduleResult& result Time earliestTargetStart = addOrMax(sourceStart, graph.nodes[edge.source].cost); if (sourceCpu != targetCpu) - earliestTargetStart = addOrMax(earliestTargetStart, edge.transferCost); + earliestTargetStart = addOrMax( + earliestTargetStart, getPeftTransferTime(edge.transferCost, sourceCpu, targetCpu, processorCount)); if (targetStart < earliestTargetStart) { std::string message = llvm::formatv("merge scheduling: dependency legality failed between tasks {0} and {1}", graph.nodes[edge.source].originalOrder, @@ -115,7 +119,10 @@ MergeScheduleResult MergeSchedulingAnalysis::run() { static_cast(crossbarCountInCore.getValue()), entryOp->getContext()}); } - verifySchedule(graph, schedule, static_cast(crossbarCountInCore.getValue())); + verifySchedule(graph, + schedule, + static_cast(crossbarCountInCore.getValue()), + options.processorCount); return schedule; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index 6c9f486..ca48f10 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -4,6 +4,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" +#include #include #include #include @@ -21,6 +22,63 @@ struct ScheduledTask { Time endTime = 0; }; +struct MeshModel { + size_t rows = 1; + size_t cols = 1; + long double averageDistance = 0.0L; + + static MeshModel infer(size_t processorCount) { + MeshModel model; + if (processorCount == 0) + return model; + + model.rows = static_cast(std::sqrt(static_cast(processorCount))); + if (model.rows == 0) + model.rows = 1; + while (model.rows > 1 && processorCount % model.rows != 0) + --model.rows; + model.cols = (processorCount + model.rows - 1) / model.rows; + + auto averageAxisDistance = [](size_t size) -> long double { + if (size <= 1) + return 0.0L; + return static_cast(size * size - 1) / (3.0L * static_cast(size)); + }; + model.averageDistance = averageAxisDistance(model.rows) + averageAxisDistance(model.cols); + return model; + } + + std::pair getCoord(size_t processor) const { + return {processor / cols, processor % cols}; + } + + size_t getDistance(size_t lhs, size_t rhs) const { + auto [lhsRow, lhsCol] = getCoord(lhs); + auto [rhsRow, rhsCol] = getCoord(rhs); + size_t rowDistance = lhsRow > rhsRow ? lhsRow - rhsRow : rhsRow - lhsRow; + size_t colDistance = lhsCol > rhsCol ? lhsCol - rhsCol : rhsCol - lhsCol; + return rowDistance + colDistance; + } + + Time scaleTransferCost(Time transferCost, size_t sourceProcessor, size_t targetProcessor) const { + if (sourceProcessor == targetProcessor || transferCost == 0) + return 0; + long double distance = static_cast(getDistance(sourceProcessor, targetProcessor)); + long double scale = averageDistance > 0.0L ? distance / averageDistance : 1.0L; + scale = std::max(0.25L, scale); + return static_cast