From 285773fa551c44a09483ef74c9836f89512193f6 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 4 May 2026 19:30:40 +0200 Subject: [PATCH] rework actually broken dcp merge + compute re-batching (still to refine) --- README.md | 2 +- src/PIM/Conversion/SpatialToPim/Patterns.cpp | 90 ++- .../SpatialToPim/SpatialToPimPass.cpp | 78 ++- .../DCPGraph/DCPAnalysis.cpp | 47 +- .../DCPGraph/DCPAnalysis.hpp | 2 + .../MergeComputeNodesPass.cpp | 600 +++++++++++++++--- validation/raptor.py | 17 +- validation/validate.py | 13 +- validation/validate_one.py | 20 +- 9 files changed, 696 insertions(+), 173 deletions(-) diff --git a/README.md b/README.md index 5017289..5443fd4 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ validate.py \ --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ --onnx-include-dir ../onnx-mlir/include \ --operations-dir ./networks/yolo11n/depth_04 \ - --crossbar-size 2048 --crossbar-count 256 + --crossbar-size 2048 --crossbar-count 256 --core-count 1000 ``` Available networks under `validation/networks/`: `vgg16`, `yolo11n`. diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index ec513ff..52228c9 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -23,21 +23,42 @@ using namespace mlir; namespace onnx_mlir { namespace { +static std::optional getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) { + if (auto compute = dyn_cast(owner)) { + unsigned inputCount = compute.getInputs().size(); + if (inputCount == 0) + return std::nullopt; + + unsigned inputBegin = compute->getNumOperands() - inputCount; + if (operandNumber < inputBegin) + return std::nullopt; + return operandNumber - inputBegin; + } + + if (auto computeBatch = dyn_cast(owner)) { + unsigned inputCount = computeBatch.getInputs().size(); + if (inputCount == 0) + return std::nullopt; + + unsigned inputBegin = computeBatch->getNumOperands() - inputCount; + if (operandNumber < inputBegin) + return std::nullopt; + return operandNumber - inputBegin; + } + + return std::nullopt; +} + struct MoveExtractSliceIntoCompute final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override { - Location loc = extractSliceOp.getLoc(); - if (!isa(extractSliceOp->getParentOp())) return failure(); for (auto& uses : extractSliceOp->getUses()) { if (isa(uses.getOwner())) { - auto spatCompute = cast(uses.getOwner()); - if (spatCompute.getInputs().empty()) - return failure(); - if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex()) + if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber())) return failure(); } else if (isa_and_present(uses.getOwner()->getParentOp())) { @@ -50,7 +71,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetUses())) { if (auto spatCompute = dyn_cast(uses.getOwner())) { - auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); if (BBArgValue.use_empty()) @@ -69,7 +93,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern(uses.getOwner())) { - auto BBArgIndex = uses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); if (BBArgValue.use_empty()) @@ -165,8 +192,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { @@ -183,8 +212,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { @@ -201,7 +232,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOfType()) { rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); @@ -240,8 +271,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); auto newConst = rewriter.clone(*constantOp); @@ -253,8 +286,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); auto newConst = rewriter.clone(*constantOp); @@ -265,11 +300,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOfType()) { - if (!mapSpatComputeToConst.contains(parent)) { - rewriter.setInsertionPoint(&parent.getBody().front().front()); - auto newConst = rewriter.clone(*constantOp); + else if (auto parent = constUsers->getParentOfType()) { + if (!mapSpatComputeToConst.contains(parent)) { + rewriter.setInsertionPoint(&parent.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)}); } constUses.set(mapSpatComputeToConst[parent.getOperation()]); @@ -285,9 +319,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOp(); rewriter.eraseOp(constantOp); return success(); } @@ -333,7 +365,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(argUser)) { - auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); @@ -347,7 +382,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(argUser)) { - auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 86f3bd5..5ec197b 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -11,20 +11,15 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_os_ostream.h" #include -#include #include #include #include @@ -34,10 +29,8 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" -#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; using namespace onnx_mlir; @@ -214,11 +207,12 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering"); return; } - SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), rewriter.getIndexAttr(0)}; + SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), + rewriter.getIndexAttr(0)}; SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto rowSlice = tensor::ExtractSliceOp::create( - rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides); + auto rowSlice = + tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides); replacements.push_back(rowSlice.getResult()); } @@ -263,19 +257,19 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); for (Operation& op : llvm::make_early_inc_range(block.without_terminator())) - if (!chainSet.contains(&op) - && !isa(op)) + if (!chainSet.contains(&op) && !isa(op)) return failure(); helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); return success(); } -static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { +static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) return false; - if (!llvm::all_of(computeOp.getResult(0).getUsers(), - [](Operation* user) { return isa(user); })) + if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { + return isa(user); + })) return false; Block& block = computeOp.getBody().front(); @@ -447,8 +441,7 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndice auto hasStaticValues = [](ArrayRef values) { return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); }; - if (!hasStaticValues(extractSliceOp.getStaticOffsets()) - || !hasStaticValues(extractSliceOp.getStaticSizes()) + if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes()) || !hasStaticValues(extractSliceOp.getStaticStrides())) return failure(); @@ -510,10 +503,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndice return success(); } -static void cloneHelperChain(Value sourceValue, - ArrayRef helperChain, - IRRewriter& rewriter, - Value& clonedValue) { +static void +cloneHelperChain(Value sourceValue, ArrayRef helperChain, IRRewriter& rewriter, Value& clonedValue) { IRMapping mapping; mapping.map(sourceValue, sourceValue); clonedValue = sourceValue; @@ -734,7 +725,7 @@ void SpatialToPimPass::runOnOperation() { void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) { Location loc = computeOp->getLoc(); - if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter)) + if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter)) return; SmallVector helperChain; @@ -835,7 +826,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter auto storedType = dyn_cast(yieldValue.getType()); if (!storedType) { - computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); + computeOp.emitOpError( + "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); signalPassFailure(); return; } @@ -848,10 +840,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx; SmallVector destinationIndices; - if (failed(mapIndicesThroughHelperChain(sourceIndices, - concatReturnUse->concatShape, - concatReturnUse->helperChain, - destinationIndices))) { + if (failed(mapIndicesThroughHelperChain( + sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) { computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); signalPassFailure(); return; @@ -897,9 +887,12 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter rewriter.replaceOpWithNewOp(yieldOp); // Replace `spat.compute` with `pim.core` + SmallVector computeWeights; + if (!computeOp.getWeights().empty()) + computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); rewriter.setInsertionPointAfter(computeOp); auto coreOp = PimCoreOp::create( - rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); + rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); auto& coreOpBlocks = coreOp.getBody().getBlocks(); for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) if (!blockArg.use_empty()) @@ -933,15 +926,19 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc } SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); + SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); + SmallVector batchInputs; + if (!computeBatchOp.getInputs().empty()) + batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); rewriter.setInsertionPointAfter(computeBatchOp); auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, loc, rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), - computeBatchOp.getWeights(), - computeBatchOp.getInputs()); + ValueRange(batchWeights), + ValueRange(batchInputs)); coreBatchOp.getProperties().setOperandSegmentSizes( - {static_cast(computeBatchOp.getWeights().size()), static_cast(computeBatchOp.getInputs().size())}); + {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); SmallVector blockArgTypes; @@ -1124,13 +1121,13 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew std::string outputName = "output_" + std::to_string(index); rewriter.setInsertionPoint(returnOp.getParentOp()); memref::GlobalOp::create(rewriter, - returnOp.getLoc(), - rewriter.getStringAttr(outputName), - rewriter.getStringAttr("private"), - TypeAttr::get(memRefType), - {}, - {}, - {}); + returnOp.getLoc(), + rewriter.getStringAttr(outputName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); outputTensors.push_back( [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value { auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName); @@ -1210,8 +1207,9 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri if (auto computeOp = dyn_cast(op)) { markOpToRemove(computeOp); - for (Value input : computeOp.getInputs()) - markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); + if (!computeOp.getInputs().empty()) + for (Value input : computeOp.getInputs()) + markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); return; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index 35534a4..332827d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -184,14 +184,40 @@ std::optional getOriginalComputeInstance(Value value) { SmallVector collectComputeInstances(Operation* entryOp) { SmallVector instances; + auto isUsedAsWeightOnly = [](Operation* producerOp) { + if (producerOp->getNumResults() == 0) + return false; + for (Value result : producerOp->getResults()) { + if (result.use_empty()) + return false; + for (Operation* user : result.getUsers()) { + if (auto compute = dyn_cast(user)) { + if (!llvm::is_contained(compute.getWeights(), result)) + return false; + continue; + } + if (auto batch = dyn_cast(user)) { + if (!llvm::is_contained(batch.getWeights(), result)) + return false; + continue; + } + return false; + } + } + return true; + }; for (Region& region : entryOp->getRegions()) { for (Block& block : region) { for (Operation& op : block) { if (auto spatCompute = dyn_cast(&op)) { + if (isUsedAsWeightOnly(spatCompute.getOperation())) + continue; instances.push_back({spatCompute.getOperation(), 0, 1}); continue; } if (auto batch = dyn_cast(&op)) { + if (isUsedAsWeightOnly(batch.getOperation())) + continue; size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) instances.push_back(getBatchChunkForIndex(batch, chunkIndex)); @@ -582,10 +608,13 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe } result.dominanceOrderCompute.reserve(computeInstances.size()); + llvm::DenseMap nextCpuSlot; for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) { size_t cpu = originalComputeToCpu[originalIndex]; result.dominanceOrderCompute.push_back(computeInstance); result.computeToCpuMap[computeInstance] = cpu; + result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++; + result.computeToAestMap[computeInstance] = originalIndex; result.cpuToLastComputeMap[cpu] = computeInstance; } for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap) @@ -603,8 +632,12 @@ DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef(task.aest); + } result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex]; result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]); } @@ -671,6 +704,16 @@ DCPAnalysisResult DCPAnalysis::run() { } } + if (coresCount.getValue() > 0) { + size_t schedulingCpuBudget = getSchedulingCpuBudget(); + bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) { + auto batch = dyn_cast(instance.op); + return batch && static_cast(batch.getLaneCount()) > schedulingCpuBudget; + }); + if (needsExactScheduledBatches) + return runLegacyDcp(computeInstances, edges, entryOp->getContext()); + } + if (dcpCriticalWindowSize.getValue() == 0) return runLegacyDcp(computeInstances, edges, entryOp->getContext()); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index dd8adfa..0a54a6c 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -25,6 +25,8 @@ struct ComputeInstance { struct DCPAnalysisResult { std::vector dominanceOrderCompute; llvm::DenseMap computeToCpuMap; + llvm::DenseMap computeToCpuSlotMap; + llvm::DenseMap computeToAestMap; llvm::DenseSet isLastComputeOfCpu; llvm::DenseMap cpuToLastComputeMap; }; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 4dc9713..7cb69b0 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -36,7 +37,6 @@ #include "DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" using namespace mlir; @@ -61,7 +61,7 @@ static size_t getFastPathCpuBudget() { static size_t getBatchChunkTargetCount(int32_t laneCount) { assert(laneCount > 0 && "laneCount must be positive"); - return std::min(static_cast(laneCount), std::max(1, getFastPathCpuBudget())); + return std::min(static_cast(laneCount), std::max(1, static_cast(getFastPathCpuBudget()))); } static ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { @@ -129,6 +129,23 @@ std::optional getProducerValueRef(Value value) { static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast(schedulerCpu + 1); } +static size_t getMaterializationCpuBudget(size_t laneCount) { + if (coresCount.getValue() > 0) + return static_cast(coresCount.getValue()); + return std::max(1, laneCount); +} + +static SmallVector getMaterializedBatchCoreIds(size_t startCpu, size_t laneCount) { + size_t cpuBudget = getMaterializationCpuBudget(laneCount); + assert(laneCount <= cpuBudget && "materialized batch exceeds available CPUs"); + + SmallVector coreIds; + coreIds.reserve(laneCount); + for (size_t laneOffset = 0; laneOffset < laneCount; ++laneOffset) + coreIds.push_back(getPhysicalCoreId((startCpu + laneOffset) % cpuBudget)); + return coreIds; +} + static SmallVector getBatchCoreIds(Operation* op, size_t laneCount) { if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); @@ -143,6 +160,14 @@ static std::optional getComputeCoreId(SpatCompute compute) { return std::nullopt; } +static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase"; + +static std::optional getComputeRebatchPhase(SpatCompute compute) { + if (auto phaseAttr = compute->getAttrOfType(kRebatchPhaseAttrName)) + return static_cast(phaseAttr.getInt()); + return std::nullopt; +} + static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { if (!lhs || !rhs) return false; @@ -152,6 +177,8 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { return false; if (lhs.getWeights().size() != rhs.getWeights().size()) return false; + if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs)) + return false; if (!llvm::equal(lhs.getWeights(), rhs.getWeights())) return false; @@ -841,10 +868,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { } for (auto compute : group) { + compute->removeAttr(kRebatchPhaseAttrName); consumed.insert(compute); rewriter.eraseOp(compute); } } + + for (auto compute : funcOp.getOps()) + compute->removeAttr(kRebatchPhaseAttrName); } struct ComputeMotifInfo { @@ -1329,8 +1360,9 @@ public: LazyInsertComputeResult( ComputeValueResults computeValueResults, + size_t producerCpu, std::function>(size_t, size_t)> channelInserter) - : computeResults(computeValueResults), channelInserter(channelInserter) {} + : computeResults(computeValueResults), producerCpu(producerCpu), channelInserter(channelInserter) {} struct ChannelOrLocalOp { Value data; @@ -1339,6 +1371,9 @@ public: }; ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex, size_t targetCpu) { + if (targetCpu == producerCpu) + return {computeResults.getOuter(resultIndex), false, {}}; + Value innerValue = computeResults.getInner(resultIndex); auto [channelInfo, channelSendInserter] = channelInserter(resultIndex, targetCpu); InsertPoint sendInsertPoint; @@ -1353,6 +1388,7 @@ public: private: ComputeValueResults computeResults; + size_t producerCpu = 0; std::function>(size_t, size_t)> channelInserter; }; @@ -1378,28 +1414,158 @@ public: mergeTriviallyConnectedComputes(getOperation()); emitMotifProfile(getOperation()); + func::FuncOp func = getOperation(); + Location loc = func.getLoc(); DCPAnalysisResult& analysisResult = getAnalysis().getResult(); - DenseSet materializedInstances; - for (size_t index = 0; index < analysisResult.dominanceOrderCompute.size(); ++index) { - ComputeInstance currentInstance = analysisResult.dominanceOrderCompute[index]; - if (!materializedInstances.insert(currentInstance).second) - continue; - - size_t cpu = analysisResult.computeToCpuMap.at(currentInstance); - if (auto batch = dyn_cast(currentInstance.op)) { - createNewBatchCompute(batch, currentInstance.laneStart, currentInstance.laneCount, cpu, analysisResult); - continue; - } - - auto scalarCompute = cast(currentInstance.op); - auto [newCompute, computeValueResults] = createNewComputeNode(scalarCompute, cpu, analysisResult); - newComputeNodeResults.insert({currentInstance, createLazyComputeResult(newCompute, computeValueResults, cpu)}); - } - DenseSet toEraseSet; for (ComputeInstance instance : analysisResult.dominanceOrderCompute) toEraseSet.insert(instance.op); + struct ScheduledTask { + ComputeInstance key; + Operation* sourceOp = nullptr; + size_t cpu = 0; + size_t slot = 0; + size_t order = 0; + }; + struct ChannelInfo { + int64_t channelId = -1; + int32_t sourceCoreId = -1; + int32_t targetCoreId = -1; + }; + struct CpuProgram { + SpatCompute op; + Block* block = nullptr; + DenseMap externalInputMap; + DenseMap weightToIndex; + }; + + auto getTaskInputs = [&](const ScheduledTask& task) { + SmallVector inputs; + if (auto compute = dyn_cast(task.sourceOp)) { + llvm::append_range(inputs, compute.getInputs()); + return inputs; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + if (!batch.getInputs().empty()) + inputs.push_back(batch.getInputs()[lane]); + return inputs; + }; + + auto getTaskWeights = [&](const ScheduledTask& task) { + SmallVector weights; + if (auto compute = dyn_cast(task.sourceOp)) { + llvm::append_range(weights, compute.getWeights()); + return weights; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + weights.push_back(batch.getWeights()[lane]); + return weights; + }; + + auto getTaskOutputValues = [&](const ScheduledTask& task) { + SmallVector outputs; + if (auto compute = dyn_cast(task.sourceOp)) { + for (Value result : compute.getResults()) + outputs.push_back(result); + return outputs; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + if (!batch.getOutputs().empty()) + outputs.push_back(batch.getOutputs()[lane]); + return outputs; + }; + + auto getTaskOutputTypes = [&](const ScheduledTask& task) { + SmallVector resultTypes; + if (auto compute = dyn_cast(task.sourceOp)) { + llvm::append_range(resultTypes, compute.getResultTypes()); + return resultTypes; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + if (!batch.getOutputs().empty()) + resultTypes.push_back(batch.getOutputs()[lane].getType()); + return resultTypes; + }; + + auto getTaskTemplateBlock = [&](const ScheduledTask& task) -> Block& { + if (auto compute = dyn_cast(task.sourceOp)) + return compute.getBody().front(); + return cast(task.sourceOp).getBody().front(); + }; + + auto appendUniqueValue = [](SmallVectorImpl& values, DenseSet& seen, Value value) { + if (seen.insert(value).second) + values.push_back(value); + }; + + DenseMap taskByKey; + DenseMap> tasksByCpu; + SmallVector orderedCpus; + DenseSet seenCpus; + DenseSet internalInputOpsToErase; + DenseMap isInternalInputOpCache; + size_t nextOrder = 0; + auto markCpuSeen = [&](size_t cpu) { + if (seenCpus.insert(cpu).second) + orderedCpus.push_back(cpu); + }; + for (ComputeInstance scheduledInstance : analysisResult.dominanceOrderCompute) { + size_t cpu = analysisResult.computeToCpuMap.at(scheduledInstance); + ScheduledTask task {scheduledInstance, + scheduledInstance.op, + cpu, + analysisResult.computeToCpuSlotMap.lookup(scheduledInstance), + nextOrder++}; + taskByKey[task.key] = task; + tasksByCpu[cpu].push_back(task); + markCpuSeen(cpu); + } + + llvm::sort(orderedCpus); + for (size_t cpu : orderedCpus) { + llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { + if (lhs.slot != rhs.slot) + return lhs.slot < rhs.slot; + return lhs.order < rhs.order; + }); + } + + std::function isInternalInputOp = [&](Operation* op) { + auto it = isInternalInputOpCache.find(op); + if (it != isInternalInputOpCache.end()) + return it->second; + + auto extract = dyn_cast_or_null(op); + if (!extract) + return isInternalInputOpCache[op] = false; + + for (Value result : extract->getResults()) { + for (Operation* user : result.getUsers()) { + if (toEraseSet.contains(user)) + continue; + if (isInternalInputOp(user)) + continue; + return isInternalInputOpCache[op] = false; + } + } + return isInternalInputOpCache[op] = true; + }; + + auto collectInternalInputOps = [&](Value value) { + Operation* op = value.getDefiningOp(); + while (auto extract = dyn_cast_if_present(op)) { + if (isInternalInputOp(extract.getOperation())) + internalInputOpsToErase.insert(extract.getOperation()); + value = extract.getSource(); + op = value.getDefiningOp(); + } + }; + DenseSet externalUsersToMove; auto collectExternalUsers = [&](Operation* op, auto&& collectExternalUsers) -> void { if (!externalUsersToMove.insert(op).second) @@ -1413,28 +1579,294 @@ public: } }; - DenseSet erasedOps; - for (ComputeInstance instance : llvm::reverse(analysisResult.dominanceOrderCompute)) { - if (!erasedOps.insert(instance.op).second) - continue; - Operation* oldOp = instance.op; - if (Operation* newOp = oldToNewOpMap.lookup(oldOp)) { - for (unsigned i = 0; i < oldOp->getNumResults(); ++i) { - for (auto& use : llvm::make_early_inc_range(oldOp->getResult(i).getUses())) { + DenseMap>> remoteSendsByTask; + DenseMap>> remoteInputsByTask; + DenseMap> cpuExternalInputs; + DenseMap> cpuWeights; + DenseMap> cpuExternalOutputs; + DenseMap> seenExternalInputsByCpu; + DenseMap> seenWeightsByCpu; + + for (size_t cpu : orderedCpus) { + for (const ScheduledTask& task : tasksByCpu[cpu]) { + auto taskWeights = getTaskWeights(task); + for (Value weight : taskWeights) + appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight); + + auto taskInputs = getTaskInputs(task); + auto& remoteInputs = remoteInputsByTask[task.key]; + remoteInputs.resize(taskInputs.size()); + for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { + auto producerRef = getProducerValueRef(input); + if (producerRef) { + collectInternalInputOps(input); + auto producerIt = taskByKey.find(producerRef->instance); + if (producerIt != taskByKey.end()) { + if (producerIt->second.cpu != cpu) { + ChannelInfo info { + nextChannelId++, + getPhysicalCoreId(producerIt->second.cpu), + getPhysicalCoreId(cpu), + }; + remoteInputs[inputIndex] = info; + auto& perResultChannels = remoteSendsByTask[producerRef->instance]; + if (perResultChannels.empty()) + perResultChannels.resize(getTaskOutputTypes(producerIt->second).size()); + perResultChannels[producerRef->resultIndex].push_back(info); + } + continue; + } + } + appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input); + } + + auto taskOutputs = getTaskOutputValues(task); + for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { + bool hasExternalUser = false; + for (auto& use : output.getUses()) { Operation* useOwner = use.getOwner(); - if (!toEraseSet.contains(useOwner)) { - use.assign(newOp->getResult(i)); - if (!isa(useOwner) && useOwner->isBeforeInBlock(newOp)) - collectExternalUsers(useOwner, collectExternalUsers); + if (toEraseSet.contains(useOwner)) + continue; + hasExternalUser = true; + if (!isa(useOwner)) + collectExternalUsers(useOwner, collectExternalUsers); + } + if (hasExternalUser) + cpuExternalOutputs[cpu].push_back({task.key, resultIndex}); + } + } + } + + auto returnOp = cast(func.getBody().front().getTerminator()); + IRRewriter rewriter(&getContext()); + DenseMap cpuPrograms; + DenseMap oldToNewExternalValueMap; + + for (size_t cpu : orderedCpus) { + SmallVector operands; + operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size()); + llvm::append_range(operands, cpuWeights[cpu]); + llvm::append_range(operands, cpuExternalInputs[cpu]); + + SmallVector resultTypes; + resultTypes.reserve(cpuExternalOutputs[cpu].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { + ScheduledTask task = taskByKey.at(outputRef.instance); + resultTypes.push_back(getTaskOutputTypes(task)[outputRef.resultIndex]); + } + + rewriter.setInsertionPoint(returnOp); + auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(cpuWeights[cpu].size()), static_cast(cpuExternalInputs[cpu].size())}); + newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(cpu))); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + blockArgTypes.reserve(cpuExternalInputs[cpu].size()); + blockArgLocs.reserve(cpuExternalInputs[cpu].size()); + for (Value input : cpuExternalInputs[cpu]) { + blockArgTypes.push_back(input.getType()); + blockArgLocs.push_back(loc); + } + Block* newBlock = + rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + + CpuProgram program; + program.op = newCompute; + program.block = newBlock; + for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu])) + program.weightToIndex[weight] = weightIndex; + for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu])) + program.externalInputMap[input] = newBlock->getArgument(inputIndex); + for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) { + ScheduledTask task = taskByKey.at(outputRef.instance); + oldToNewExternalValueMap[getTaskOutputValues(task)[outputRef.resultIndex]] = newCompute.getResult(resultIndex); + } + cpuPrograms[cpu] = std::move(program); + } + + DenseMap> producedValuesByTask; + for (size_t cpu : orderedCpus) { + CpuProgram& program = cpuPrograms[cpu]; + IRRewriter cpuRewriter(&getContext()); + cpuRewriter.setInsertionPointToEnd(program.block); + + for (const ScheduledTask& task : tasksByCpu[cpu]) { + SmallVector taskInputs = getTaskInputs(task); + auto taskWeights = getTaskWeights(task); + Block& templateBlock = getTaskTemplateBlock(task); + + SmallVector resolvedInputs; + resolvedInputs.reserve(taskInputs.size()); + auto remoteInputsIt = remoteInputsByTask.find(task.key); + for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { + auto producerRef = getProducerValueRef(input); + if (producerRef) { + auto producerIt = taskByKey.find(producerRef->instance); + if (producerIt != taskByKey.end()) { + if (producerIt->second.cpu == cpu) { + auto producedIt = producedValuesByTask.find(producerRef->instance); + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { + task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization") + << " consumerCpu=" << cpu << " consumerSlot=" << task.slot + << " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot + << " producerLaneStart=" << producerRef->instance.laneStart + << " producerLaneCount=" << producerRef->instance.laneCount; + signalPassFailure(); + return; + } + resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); + continue; + } + const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; + auto receive = + spatial::SpatChannelReceiveOp::create(cpuRewriter, + loc, + input.getType(), + cpuRewriter.getI64IntegerAttr(channelInfo.channelId), + cpuRewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + cpuRewriter.getI32IntegerAttr(channelInfo.targetCoreId)); + resolvedInputs.push_back(receive.getResult()); + continue; + } + } + resolvedInputs.push_back(program.externalInputMap.at(input)); + } + + SmallVector taskYieldValues; + cpuRewriter.setInsertionPointToEnd(program.block); + if (isa(task.sourceOp)) { + IRMapping mapper; + for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments())) + mapper.map(oldArg, resolvedInputs[argIndex]); + + for (Operation& op : templateBlock) { + if (auto yield = dyn_cast(&op)) { + for (Value yieldOperand : yield.getOperands()) + taskYieldValues.push_back(mapper.lookup(yieldOperand)); + continue; + } + + Operation* clonedOp = cpuRewriter.clone(op, mapper); + if (auto oldWeightedMvmOp = dyn_cast(&op)) { + auto newWeightedMvmOp = cast(clonedOp); + Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; + newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight)); + } + if (auto oldWeightedVmmOp = dyn_cast(&op)) { + auto newWeightedVmmOp = cast(clonedOp); + Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()]; + newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); } } } + else { + for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) { + IRMapping mapper; + if (templateBlock.getNumArguments() == 1) + mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]); + + for (Operation& op : templateBlock) { + if (auto yield = dyn_cast(&op)) { + for (Value yieldOperand : yield.getOperands()) + taskYieldValues.push_back(mapper.lookup(yieldOperand)); + continue; + } + + Operation* clonedOp = cpuRewriter.clone(op, mapper); + if (auto oldWeightedMvmOp = dyn_cast(&op)) { + if (oldWeightedMvmOp.getWeightIndex() != 0) { + task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); + signalPassFailure(); + return; + } + auto newWeightedMvmOp = cast(clonedOp); + newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); + } + if (auto oldWeightedVmmOp = dyn_cast(&op)) { + if (oldWeightedVmmOp.getWeightIndex() != 0) { + task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); + signalPassFailure(); + return; + } + auto newWeightedVmmOp = cast(clonedOp); + newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); + } + } + } + } + + producedValuesByTask[task.key] = taskYieldValues; + if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) { + for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) { + if (sendInfos.empty()) + continue; + Value producedValue = taskYieldValues[resultIndex]; + for (const ChannelInfo& sendInfo : sendInfos) + spatial::SpatChannelSendOp::create(cpuRewriter, + loc, + cpuRewriter.getI64IntegerAttr(sendInfo.channelId), + cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId), + cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId), + producedValue); + } + } } - oldOp->erase(); + + SmallVector yieldValues; + yieldValues.reserve(cpuExternalOutputs[cpu].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { + auto producedIt = producedValuesByTask.find(outputRef.instance); + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { + ScheduledTask task = taskByKey.at(outputRef.instance); + task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization") + << " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart; + signalPassFailure(); + return; + } + yieldValues.push_back(producedIt->second[outputRef.resultIndex]); + } + spatial::SpatYieldOp::create(cpuRewriter, loc, ValueRange(yieldValues)); + } + + for (auto [oldValue, newValue] : oldToNewExternalValueMap) { + for (auto& use : llvm::make_early_inc_range(oldValue.getUses())) + if (!toEraseSet.contains(use.getOwner())) + use.assign(newValue); + } + + DenseSet allOpsToErase = toEraseSet; + for (Operation* op : internalInputOpsToErase) + allOpsToErase.insert(op); + + SmallVector orderedOpsToErase; + for (Operation& op : func.getBody().front()) + if (allOpsToErase.contains(&op)) + orderedOpsToErase.push_back(&op); + for (Operation* op : llvm::reverse(orderedOpsToErase)) { + SmallVector remainingUsers; + for (Value result : op->getResults()) + for (Operation* user : result.getUsers()) + remainingUsers.push_back(user); + if (!remainingUsers.empty()) { + llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n"; + llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n"; + op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions()); + llvm::errs() << "\n"; + for (Operation* user : remainingUsers) { + llvm::errs() << " user: " << user->getName() + << " erase-set=" << (allOpsToErase.contains(user) ? "yes" : "no") << "\n"; + user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions()); + llvm::errs() << "\n"; + } + op->emitOpError("still has uses during per-cpu merge cleanup"); + signalPassFailure(); + return; + } + op->erase(); } - func::FuncOp func = getOperation(); - auto returnOp = cast(func.getBody().front().getTerminator()); SmallVector orderedUsersToMove; for (Operation& op : func.getBody().front()) { if (&op == returnOp.getOperation()) @@ -1445,9 +1877,13 @@ public: for (Operation* op : orderedUsersToMove) op->moveBefore(returnOp); - sinkChannelsIntoComputes(func, nextChannelId); rebatchEquivalentComputes(func, nextChannelId); compactScalarChannelRuns(func, nextChannelId); + if (!sortTopologically(&func.getBody().front())) { + func.emitOpError("failed to topologically order merged Spatial IR"); + signalPassFailure(); + return; + } dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size()); } @@ -1477,16 +1913,18 @@ private: Value resolvedInput = input; if (auto producerRef = getProducerValueRef(input)) { LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); - auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); - (void) isChannel; - (void) channelVal; - resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, - loc, - input.getType(), - rewriter.getI64IntegerAttr(channelInfo.channelId), - rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) - .getResult(); + auto [channelVal, isChannel, channelInfo] = + producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); + if (isChannel) + resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + input.getType(), + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) + .getResult(); + else + resolvedInput = channelVal; } newComputeOperands.push_back(resolvedInput); @@ -1532,7 +1970,8 @@ private: uint32_t firstLane, uint32_t laneCount, size_t currentCpu, - const DCPAnalysisResult& analysisResult) { + const DCPAnalysisResult& analysisResult, + std::optional rebatchPhase = std::nullopt) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); @@ -1547,24 +1986,29 @@ private: for (uint32_t lane = firstLane; lane < firstLane + laneCount; ++lane) { weights.push_back(batch.getWeights()[lane]); - resultTypes.push_back(batch.getOutputs()[lane].getType()); + if (!batch.getOutputs().empty()) + resultTypes.push_back(batch.getOutputs()[lane].getType()); - Value input = batch.getInputs()[lane]; - Value resolvedInput = input; - if (auto producerRef = getProducerValueRef(input)) { - LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); - auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); - (void) isChannel; - (void) channelVal; - resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, - loc, - input.getType(), - rewriter.getI64IntegerAttr(channelInfo.channelId), - rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) - .getResult(); + if (!batch.getInputs().empty()) { + Value input = batch.getInputs()[lane]; + Value resolvedInput = input; + if (auto producerRef = getProducerValueRef(input)) { + LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); + auto [channelVal, isChannel, channelInfo] = + producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); + if (isChannel) + resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + input.getType(), + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) + .getResult(); + else + resolvedInput = channelVal; + } + inputs.push_back(resolvedInput); } - inputs.push_back(resolvedInput); } Block& templateBlock = batch.getBody().front(); @@ -1574,11 +2018,17 @@ private: compute.getProperties().setOperandSegmentSizes( {static_cast(weights.size()), static_cast(inputs.size())}); compute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu))); + if (rebatchPhase) + compute->setAttr(kRebatchPhaseAttrName, rewriter.getI64IntegerAttr(*rebatchPhase)); - auto* newBlock = rewriter.createBlock( - &compute.getBody(), compute.getBody().end(), TypeRange {templateBlock.getArgument(0).getType()}, {loc}); + SmallVector blockArgTypes; + if (templateBlock.getNumArguments() == 1) + blockArgTypes.push_back(templateBlock.getArgument(0).getType()); + SmallVector blockArgLocs(templateBlock.getNumArguments(), loc); + auto* newBlock = rewriter.createBlock(&compute.getBody(), compute.getBody().end(), blockArgTypes, blockArgLocs); IRMapping mapper; - mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); + if (templateBlock.getNumArguments() == 1) + mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : templateBlock) rewriter.clone(op, mapper); @@ -1600,14 +2050,16 @@ private: ValueRange(weights), ValueRange(inputs)); rebatched->setAttr(onnx_mlir::kCoreIdAttrName, - rewriter.getDenseI32ArrayAttr(SmallVector(laneCount, getPhysicalCoreId(currentCpu)))); + rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount))); - auto* newBlock = rewriter.createBlock(&rebatched.getBody(), - rebatched.getBody().end(), - TypeRange {templateBlock.getArgument(0).getType()}, - SmallVector(1, loc)); + SmallVector blockArgTypes; + if (templateBlock.getNumArguments() == 1) + blockArgTypes.push_back(templateBlock.getArgument(0).getType()); + SmallVector blockArgLocs(templateBlock.getNumArguments(), loc); + auto* newBlock = rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), blockArgTypes, blockArgLocs); IRMapping mapper; - mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); + if (templateBlock.getNumArguments() == 1) + mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : templateBlock) { if (auto yield = dyn_cast(&op)) { @@ -1621,7 +2073,7 @@ private: ComputeValueResults results; results.outerValues.assign(rebatched->result_begin(), rebatched->result_end()); results.innerValues = results.outerValues; - if (results.innerValues.empty()) + if (results.innerValues.empty() && yieldOp.getNumOperands() == 1) results.innerValues.push_back(yieldOp.getOperand(0)); newComputeNodeResults.insert({ ComputeInstance {batch.getOperation(), firstLane, laneCount}, @@ -1658,7 +2110,7 @@ private: channelInfo, insertVal}; return ret; }; - return LazyInsertComputeResult(computeValueResults, insertNew); + return LazyInsertComputeResult(computeValueResults, producerCpu, insertNew); } }; diff --git a/validation/raptor.py b/validation/raptor.py index b9c147b..506cd1f 100644 --- a/validation/raptor.py +++ b/validation/raptor.py @@ -1,4 +1,5 @@ import re +import shlex import subprocess from pathlib import Path from colorama import Fore, Style @@ -37,8 +38,12 @@ def _parse_pim_pass_timings(output_text): return pass_timings +def _format_command(cmd): + return shlex.join(str(arg) for arg in cmd) + + def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, - crossbar_size, crossbar_count, cwd=None, reporter=None): + crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None): # Define the arguments, with the possibility to set crossbar size and count args = [ network_path, @@ -51,10 +56,18 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, f"--crossbar-count={crossbar_count}", "--enable-timing", ] + if core_count is not None: + args.append(f"--core-count={core_count}") + + cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args] + if reporter is not None: + reporter.log(f" Raptor command: {_format_command(cmd)}") + else: + print(f"Raptor command: {_format_command(cmd)}") try: output_text = run_command_with_reporter( - [str(raptor_onnx_path)] + [str(arg) for arg in args], + cmd, cwd=cwd, reporter=reporter, capture_output=True, diff --git a/validation/validate.py b/validation/validate.py index 357f6fa..e3c8d6c 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import argparse -import shlex import signal import subprocess import sys @@ -11,12 +10,6 @@ from validate_one import ProgressReporter, clean_workspace_artifacts, validate_n from raptor import PIM_PASS_LABELS -def format_command(cmd): - if isinstance(cmd, (list, tuple)): - return shlex.join(str(arg) for arg in cmd) - return str(cmd) - - def format_return_status(returncode): if returncode < 0: signal_num = -returncode @@ -34,8 +27,6 @@ def print_validation_error(reporter, rel, exc): file=sys.stderr, flush=True) if isinstance(exc, subprocess.CalledProcessError): print(format_return_status(exc.returncode), file=sys.stderr, flush=True) - print("Retry command:", file=sys.stderr, flush=True) - print(format_command(exc.cmd), file=sys.stderr, flush=True) else: print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True) print("=" * 72, file=sys.stderr, flush=True) @@ -65,6 +56,8 @@ def main(): ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.") ap.add_argument("--crossbar-size", type=int, default=64) ap.add_argument("--crossbar-count", type=int, default=8) + ap.add_argument("--core-count", type=int, default=None, + help="Core count to pass to Raptor. If omitted, Raptor uses its default.") ap.add_argument("--clean", action="store_true", help="Remove generated validation artifacts under each model workspace and exit.") a = ap.parse_args() @@ -114,7 +107,7 @@ def main(): try: result = validate_network( onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, - crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, + crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count, threshold=a.threshold, reporter=reporter, model_index=index, diff --git a/validation/validate_one.py b/validation/validate_one.py index 4cc4ea9..cd5c1e7 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -1,4 +1,3 @@ -import argparse import json import numpy as np import subprocess @@ -258,7 +257,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1 def validate_network(network_onnx_path, raptor_path, onnx_include_dir, - simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3, + simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3, reporter=None, model_index=1, model_total=1): network_onnx_path = Path(network_onnx_path).resolve() raptor_path = Path(raptor_path).resolve() @@ -313,7 +312,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") pim_pass_timings = compile_with_raptor( network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, - crossbar_size, crossbar_count, + crossbar_size, crossbar_count, core_count=core_count, cwd=raptor_dir, reporter=reporter) print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}") reporter.advance() @@ -350,18 +349,3 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.log("=" * 72) if owns_reporter: reporter.finish() - - -if __name__ == '__main__': - ap = argparse.ArgumentParser() - ap.add_argument("--network-onnx", required=True) - ap.add_argument("--raptor-path", required=True) - ap.add_argument("--onnx-include-dir", required=True) - a = ap.parse_args() - - simulator_dir = Path(__file__).parent.resolve() / ".." / "backend-simulators" / "pim" / "pim-simulator" - - passed = validate_network( - a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir - ) - raise SystemExit(0 if passed.passed else 1)