From 285773fa551c44a09483ef74c9836f89512193f6 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Mon, 4 May 2026 19:30:40 +0200 Subject: [PATCH 1/2] rework actually broken dcp merge + compute re-batching (still to refine) --- README.md | 2 +- src/PIM/Conversion/SpatialToPim/Patterns.cpp | 90 ++- .../SpatialToPim/SpatialToPimPass.cpp | 78 ++- .../DCPGraph/DCPAnalysis.cpp | 47 +- .../DCPGraph/DCPAnalysis.hpp | 2 + .../MergeComputeNodesPass.cpp | 600 +++++++++++++++--- validation/raptor.py | 17 +- validation/validate.py | 13 +- validation/validate_one.py | 20 +- 9 files changed, 696 insertions(+), 173 deletions(-) diff --git a/README.md b/README.md index 5017289..5443fd4 100644 --- a/README.md +++ b/README.md @@ -135,7 +135,7 @@ validate.py \ --raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ --onnx-include-dir ../onnx-mlir/include \ --operations-dir ./networks/yolo11n/depth_04 \ - --crossbar-size 2048 --crossbar-count 256 + --crossbar-size 2048 --crossbar-count 256 --core-count 1000 ``` Available networks under `validation/networks/`: `vgg16`, `yolo11n`. diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index ec513ff..52228c9 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -23,21 +23,42 @@ using namespace mlir; namespace onnx_mlir { namespace { +static std::optional getDirectComputeInputIndex(Operation* owner, unsigned operandNumber) { + if (auto compute = dyn_cast(owner)) { + unsigned inputCount = compute.getInputs().size(); + if (inputCount == 0) + return std::nullopt; + + unsigned inputBegin = compute->getNumOperands() - inputCount; + if (operandNumber < inputBegin) + return std::nullopt; + return operandNumber - inputBegin; + } + + if (auto computeBatch = dyn_cast(owner)) { + unsigned inputCount = computeBatch.getInputs().size(); + if (inputCount == 0) + return std::nullopt; + + unsigned inputBegin = computeBatch->getNumOperands() - inputCount; + if (operandNumber < inputBegin) + return std::nullopt; + return operandNumber - inputBegin; + } + + return std::nullopt; +} + struct MoveExtractSliceIntoCompute final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(mlir::tensor::ExtractSliceOp extractSliceOp, PatternRewriter& rewriter) const override { - Location loc = extractSliceOp.getLoc(); - if (!isa(extractSliceOp->getParentOp())) return failure(); for (auto& uses : extractSliceOp->getUses()) { if (isa(uses.getOwner())) { - auto spatCompute = cast(uses.getOwner()); - if (spatCompute.getInputs().empty()) - return failure(); - if (uses.getOperandNumber() < spatCompute.getInputs().getBeginOperandIndex()) + if (!getDirectComputeInputIndex(uses.getOwner(), uses.getOperandNumber())) return failure(); } else if (isa_and_present(uses.getOwner()->getParentOp())) { @@ -50,7 +71,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePatterngetUses())) { if (auto spatCompute = dyn_cast(uses.getOwner())) { - auto BBArgIndex = uses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, uses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); if (BBArgValue.use_empty()) @@ -69,7 +93,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern(uses.getOwner())) { - auto BBArgIndex = uses.getOperandNumber() - 
spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, uses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); if (BBArgValue.use_empty()) @@ -165,8 +192,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { @@ -183,8 +212,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { @@ -201,7 +232,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOfType()) { rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); @@ -240,8 +271,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); auto newConst = rewriter.clone(*constantOp); @@ -253,8 +286,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePattern(constUsers)) { - - auto BBArgIndex = constUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, constUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); auto newConst = rewriter.clone(*constantOp); @@ -265,11 +300,10 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOfType()) { - if (!mapSpatComputeToConst.contains(parent)) { - rewriter.setInsertionPoint(&parent.getBody().front().front()); - auto newConst = rewriter.clone(*constantOp); + else if (auto parent = constUsers->getParentOfType()) { + if (!mapSpatComputeToConst.contains(parent)) { + rewriter.setInsertionPoint(&parent.getBody().front().front()); + auto newConst = rewriter.clone(*constantOp); mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)}); } constUses.set(mapSpatComputeToConst[parent.getOperation()]); @@ -285,9 +319,7 @@ struct ArithConstToGlobalMemoryPattern final : OpRewritePatterngetParentOp(); rewriter.eraseOp(constantOp); return 
success(); } @@ -333,7 +365,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(argUser)) { - auto BBArgIndex = argUses.getOperandNumber() - spatCompute.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatCompute, argUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); @@ -347,7 +382,10 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern(argUser)) { - auto BBArgIndex = argUses.getOperandNumber() - spatComputeBatch.getInputs().getBeginOperandIndex(); + auto inputIndex = getDirectComputeInputIndex(spatComputeBatch, argUses.getOperandNumber()); + if (!inputIndex) + return failure(); + auto BBArgIndex = *inputIndex; auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex); rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 86f3bd5..5ec197b 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -11,20 +11,15 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" -#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/raw_os_ostream.h" #include -#include #include #include #include @@ -34,10 +29,8 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Pass/PIMPasses.h" -#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; using namespace onnx_mlir; @@ -214,11 +207,12 @@ static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewrite extractRowsOp.emitOpError("requires ranked result tensors during Spatial-to-PIM lowering"); return; } - SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), rewriter.getIndexAttr(0)}; + SmallVector offsets = {rewriter.getIndexAttr(static_cast(rowIndex)), + rewriter.getIndexAttr(0)}; SmallVector sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(numCols)}; SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - auto rowSlice = tensor::ExtractSliceOp::create( - rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides); + auto rowSlice = + tensor::ExtractSliceOp::create(rewriter, extractRowsOp.getLoc(), outputType, input, offsets, sizes, strides); replacements.push_back(rowSlice.getResult()); } @@ -263,19 +257,19 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, SmallPtrSet chainSet(reverseChain.begin(), reverseChain.end()); for (Operation& op : 
llvm::make_early_inc_range(block.without_terminator())) - if (!chainSet.contains(&op) - && !isa(op)) + if (!chainSet.contains(&op) && !isa(op)) return failure(); helperChain.assign(reverseChain.rbegin(), reverseChain.rend()); return success(); } -static bool inlineInputlessHelperComputeForBatchUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { +static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) { if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1) return false; - if (!llvm::all_of(computeOp.getResult(0).getUsers(), - [](Operation* user) { return isa(user); })) + if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) { + return isa(user); + })) return false; Block& block = computeOp.getBody().front(); @@ -447,8 +441,7 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndice auto hasStaticValues = [](ArrayRef values) { return llvm::all_of(values, [](int64_t value) { return !ShapedType::isDynamic(value); }); }; - if (!hasStaticValues(extractSliceOp.getStaticOffsets()) - || !hasStaticValues(extractSliceOp.getStaticSizes()) + if (!hasStaticValues(extractSliceOp.getStaticOffsets()) || !hasStaticValues(extractSliceOp.getStaticSizes()) || !hasStaticValues(extractSliceOp.getStaticStrides())) return failure(); @@ -510,10 +503,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef sourceIndice return success(); } -static void cloneHelperChain(Value sourceValue, - ArrayRef helperChain, - IRRewriter& rewriter, - Value& clonedValue) { +static void +cloneHelperChain(Value sourceValue, ArrayRef helperChain, IRRewriter& rewriter, Value& clonedValue) { IRMapping mapping; mapping.map(sourceValue, sourceValue); clonedValue = sourceValue; @@ -734,7 +725,7 @@ void SpatialToPimPass::runOnOperation() { void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter& rewriter) { Location loc = computeOp->getLoc(); - if (inlineInputlessHelperComputeForBatchUsers(computeOp, rewriter)) + if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter)) return; SmallVector helperChain; @@ -835,7 +826,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter auto storedType = dyn_cast(yieldValue.getType()); if (!storedType) { - computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); + computeOp.emitOpError( + "has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering"); signalPassFailure(); return; } @@ -848,10 +840,8 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter sourceIndices[dim] = concatReturnUse->sliceOffsets[dim] + idx; SmallVector destinationIndices; - if (failed(mapIndicesThroughHelperChain(sourceIndices, - concatReturnUse->concatShape, - concatReturnUse->helperChain, - destinationIndices))) { + if (failed(mapIndicesThroughHelperChain( + sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) { computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering"); signalPassFailure(); return; @@ -897,9 +887,12 @@ void SpatialToPimPass::runOnComputeOp(spatial::SpatCompute computeOp, IRRewriter rewriter.replaceOpWithNewOp(yieldOp); // Replace `spat.compute` with `pim.core` + SmallVector computeWeights; + if (!computeOp.getWeights().empty()) + computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end()); 
rewriter.setInsertionPointAfter(computeOp); auto coreOp = PimCoreOp::create( - rewriter, loc, computeOp.getWeights(), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); + rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId))); auto& coreOpBlocks = coreOp.getBody().getBlocks(); for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) if (!blockArg.use_empty()) @@ -933,15 +926,19 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc } SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId); + SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); + SmallVector batchInputs; + if (!computeBatchOp.getInputs().empty()) + batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); rewriter.setInsertionPointAfter(computeBatchOp); auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, loc, rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), - computeBatchOp.getWeights(), - computeBatchOp.getInputs()); + ValueRange(batchWeights), + ValueRange(batchInputs)); coreBatchOp.getProperties().setOperandSegmentSizes( - {static_cast(computeBatchOp.getWeights().size()), static_cast(computeBatchOp.getInputs().size())}); + {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); SmallVector blockArgTypes; @@ -1124,13 +1121,13 @@ void SpatialToPimPass::addResultBuffer(func::ReturnOp& returnOp, IRRewriter& rew std::string outputName = "output_" + std::to_string(index); rewriter.setInsertionPoint(returnOp.getParentOp()); memref::GlobalOp::create(rewriter, - returnOp.getLoc(), - rewriter.getStringAttr(outputName), - rewriter.getStringAttr("private"), - TypeAttr::get(memRefType), - {}, - {}, - {}); + returnOp.getLoc(), + rewriter.getStringAttr(outputName), + rewriter.getStringAttr("private"), + TypeAttr::get(memRefType), + {}, + {}, + {}); outputTensors.push_back( [memRefType, outputName, outRankedTensorType](IRRewriter& rewriter, Location loc) -> Value { auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, outputName); @@ -1210,8 +1207,9 @@ void SpatialToPimPass::replaceReturnOpOperands(func::ReturnOp& returnOp, IRRewri if (auto computeOp = dyn_cast(op)) { markOpToRemove(computeOp); - for (Value input : computeOp.getInputs()) - markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); + if (!computeOp.getInputs().empty()) + for (Value input : computeOp.getInputs()) + markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain); return; } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp index 35534a4..332827d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.cpp @@ -184,14 +184,40 @@ std::optional getOriginalComputeInstance(Value value) { SmallVector collectComputeInstances(Operation* entryOp) { SmallVector instances; + auto isUsedAsWeightOnly = [](Operation* producerOp) { + if (producerOp->getNumResults() == 0) + return false; + for (Value result : producerOp->getResults()) { + if (result.use_empty()) + return false; + for (Operation* user : result.getUsers()) { + if (auto compute = dyn_cast(user)) { + if 
(!llvm::is_contained(compute.getWeights(), result)) + return false; + continue; + } + if (auto batch = dyn_cast(user)) { + if (!llvm::is_contained(batch.getWeights(), result)) + return false; + continue; + } + return false; + } + } + return true; + }; for (Region& region : entryOp->getRegions()) { for (Block& block : region) { for (Operation& op : block) { if (auto spatCompute = dyn_cast(&op)) { + if (isUsedAsWeightOnly(spatCompute.getOperation())) + continue; instances.push_back({spatCompute.getOperation(), 0, 1}); continue; } if (auto batch = dyn_cast(&op)) { + if (isUsedAsWeightOnly(batch.getOperation())) + continue; size_t chunkCount = getBatchChunkTargetCount(batch.getLaneCount()); for (size_t chunkIndex = 0; chunkIndex < chunkCount; ++chunkIndex) instances.push_back(getBatchChunkForIndex(batch, chunkIndex)); @@ -582,10 +608,13 @@ DCPAnalysisResult buildResultFromVirtualGraph(const VirtualGraph& graph, ArrayRe } result.dominanceOrderCompute.reserve(computeInstances.size()); + llvm::DenseMap nextCpuSlot; for (auto [originalIndex, computeInstance] : llvm::enumerate(computeInstances)) { size_t cpu = originalComputeToCpu[originalIndex]; result.dominanceOrderCompute.push_back(computeInstance); result.computeToCpuMap[computeInstance] = cpu; + result.computeToCpuSlotMap[computeInstance] = nextCpuSlot[cpu]++; + result.computeToAestMap[computeInstance] = originalIndex; result.cpuToLastComputeMap[cpu] = computeInstance; } for (const auto& [cpu, lastCompute] : result.cpuToLastComputeMap) @@ -603,8 +632,12 @@ DCPAnalysisResult buildResultFromScheduledGraph(GraphDCP& graphDCP, ArrayRef(task.aest); + } result.cpuToLastComputeMap[cpu] = computeInstances[scheduledTasks.back().nodeIndex]; result.isLastComputeOfCpu.insert(computeInstances[scheduledTasks.back().nodeIndex]); } @@ -671,6 +704,16 @@ DCPAnalysisResult DCPAnalysis::run() { } } + if (coresCount.getValue() > 0) { + size_t schedulingCpuBudget = getSchedulingCpuBudget(); + bool needsExactScheduledBatches = llvm::any_of(computeInstances, [&](const ComputeInstance& instance) { + auto batch = dyn_cast(instance.op); + return batch && static_cast(batch.getLaneCount()) > schedulingCpuBudget; + }); + if (needsExactScheduledBatches) + return runLegacyDcp(computeInstances, edges, entryOp->getContext()); + } + if (dcpCriticalWindowSize.getValue() == 0) return runLegacyDcp(computeInstances, edges, entryOp->getContext()); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp index dd8adfa..0a54a6c 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/DCPGraph/DCPAnalysis.hpp @@ -25,6 +25,8 @@ struct ComputeInstance { struct DCPAnalysisResult { std::vector dominanceOrderCompute; llvm::DenseMap computeToCpuMap; + llvm::DenseMap computeToCpuSlotMap; + llvm::DenseMap computeToAestMap; llvm::DenseSet isLastComputeOfCpu; llvm::DenseMap cpuToLastComputeMap; }; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 4dc9713..7cb69b0 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1,3 +1,4 @@ +#include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include 
"mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" @@ -36,7 +37,6 @@ #include "DCPGraph/DCPAnalysis.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" -#include "src/Accelerators/PIM/Dialect/Spatial/Channels.hpp" using namespace mlir; @@ -61,7 +61,7 @@ static size_t getFastPathCpuBudget() { static size_t getBatchChunkTargetCount(int32_t laneCount) { assert(laneCount > 0 && "laneCount must be positive"); - return std::min(static_cast(laneCount), std::max(1, getFastPathCpuBudget())); + return std::min(static_cast(laneCount), std::max(1, static_cast(getFastPathCpuBudget()))); } static ComputeInstance getBatchChunkForIndex(SpatComputeBatch batch, size_t chunkIndex) { @@ -129,6 +129,23 @@ std::optional getProducerValueRef(Value value) { static int32_t getPhysicalCoreId(size_t schedulerCpu) { return static_cast(schedulerCpu + 1); } +static size_t getMaterializationCpuBudget(size_t laneCount) { + if (coresCount.getValue() > 0) + return static_cast(coresCount.getValue()); + return std::max(1, laneCount); +} + +static SmallVector getMaterializedBatchCoreIds(size_t startCpu, size_t laneCount) { + size_t cpuBudget = getMaterializationCpuBudget(laneCount); + assert(laneCount <= cpuBudget && "materialized batch exceeds available CPUs"); + + SmallVector coreIds; + coreIds.reserve(laneCount); + for (size_t laneOffset = 0; laneOffset < laneCount; ++laneOffset) + coreIds.push_back(getPhysicalCoreId((startCpu + laneOffset) % cpuBudget)); + return coreIds; +} + static SmallVector getBatchCoreIds(Operation* op, size_t laneCount) { if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); @@ -143,6 +160,14 @@ static std::optional getComputeCoreId(SpatCompute compute) { return std::nullopt; } +static constexpr StringLiteral kRebatchPhaseAttrName = "_pim_rebatch_phase"; + +static std::optional getComputeRebatchPhase(SpatCompute compute) { + if (auto phaseAttr = compute->getAttrOfType(kRebatchPhaseAttrName)) + return static_cast(phaseAttr.getInt()); + return std::nullopt; +} + static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { if (!lhs || !rhs) return false; @@ -152,6 +177,8 @@ static bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { return false; if (lhs.getWeights().size() != rhs.getWeights().size()) return false; + if (getComputeRebatchPhase(lhs) != getComputeRebatchPhase(rhs)) + return false; if (!llvm::equal(lhs.getWeights(), rhs.getWeights())) return false; @@ -841,10 +868,14 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { } for (auto compute : group) { + compute->removeAttr(kRebatchPhaseAttrName); consumed.insert(compute); rewriter.eraseOp(compute); } } + + for (auto compute : funcOp.getOps()) + compute->removeAttr(kRebatchPhaseAttrName); } struct ComputeMotifInfo { @@ -1329,8 +1360,9 @@ public: LazyInsertComputeResult( ComputeValueResults computeValueResults, + size_t producerCpu, std::function>(size_t, size_t)> channelInserter) - : computeResults(computeValueResults), channelInserter(channelInserter) {} + : computeResults(computeValueResults), producerCpu(producerCpu), channelInserter(channelInserter) {} struct ChannelOrLocalOp { Value data; @@ -1339,6 +1371,9 @@ public: }; ChannelOrLocalOp getAsChannelValueAndInsertSender(size_t resultIndex, size_t targetCpu) { + if (targetCpu == producerCpu) + return {computeResults.getOuter(resultIndex), false, {}}; + 
Value innerValue = computeResults.getInner(resultIndex); auto [channelInfo, channelSendInserter] = channelInserter(resultIndex, targetCpu); InsertPoint sendInsertPoint; @@ -1353,6 +1388,7 @@ public: private: ComputeValueResults computeResults; + size_t producerCpu = 0; std::function>(size_t, size_t)> channelInserter; }; @@ -1378,28 +1414,158 @@ public: mergeTriviallyConnectedComputes(getOperation()); emitMotifProfile(getOperation()); + func::FuncOp func = getOperation(); + Location loc = func.getLoc(); DCPAnalysisResult& analysisResult = getAnalysis().getResult(); - DenseSet materializedInstances; - for (size_t index = 0; index < analysisResult.dominanceOrderCompute.size(); ++index) { - ComputeInstance currentInstance = analysisResult.dominanceOrderCompute[index]; - if (!materializedInstances.insert(currentInstance).second) - continue; - - size_t cpu = analysisResult.computeToCpuMap.at(currentInstance); - if (auto batch = dyn_cast(currentInstance.op)) { - createNewBatchCompute(batch, currentInstance.laneStart, currentInstance.laneCount, cpu, analysisResult); - continue; - } - - auto scalarCompute = cast(currentInstance.op); - auto [newCompute, computeValueResults] = createNewComputeNode(scalarCompute, cpu, analysisResult); - newComputeNodeResults.insert({currentInstance, createLazyComputeResult(newCompute, computeValueResults, cpu)}); - } - DenseSet toEraseSet; for (ComputeInstance instance : analysisResult.dominanceOrderCompute) toEraseSet.insert(instance.op); + struct ScheduledTask { + ComputeInstance key; + Operation* sourceOp = nullptr; + size_t cpu = 0; + size_t slot = 0; + size_t order = 0; + }; + struct ChannelInfo { + int64_t channelId = -1; + int32_t sourceCoreId = -1; + int32_t targetCoreId = -1; + }; + struct CpuProgram { + SpatCompute op; + Block* block = nullptr; + DenseMap externalInputMap; + DenseMap weightToIndex; + }; + + auto getTaskInputs = [&](const ScheduledTask& task) { + SmallVector inputs; + if (auto compute = dyn_cast(task.sourceOp)) { + llvm::append_range(inputs, compute.getInputs()); + return inputs; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + if (!batch.getInputs().empty()) + inputs.push_back(batch.getInputs()[lane]); + return inputs; + }; + + auto getTaskWeights = [&](const ScheduledTask& task) { + SmallVector weights; + if (auto compute = dyn_cast(task.sourceOp)) { + llvm::append_range(weights, compute.getWeights()); + return weights; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + weights.push_back(batch.getWeights()[lane]); + return weights; + }; + + auto getTaskOutputValues = [&](const ScheduledTask& task) { + SmallVector outputs; + if (auto compute = dyn_cast(task.sourceOp)) { + for (Value result : compute.getResults()) + outputs.push_back(result); + return outputs; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + if (!batch.getOutputs().empty()) + outputs.push_back(batch.getOutputs()[lane]); + return outputs; + }; + + auto getTaskOutputTypes = [&](const ScheduledTask& task) { + SmallVector resultTypes; + if (auto compute = dyn_cast(task.sourceOp)) { + llvm::append_range(resultTypes, compute.getResultTypes()); + return resultTypes; + } + auto batch = cast(task.sourceOp); + for (uint32_t lane = task.key.laneStart; lane < task.key.laneStart + task.key.laneCount; ++lane) + if 
(!batch.getOutputs().empty()) + resultTypes.push_back(batch.getOutputs()[lane].getType()); + return resultTypes; + }; + + auto getTaskTemplateBlock = [&](const ScheduledTask& task) -> Block& { + if (auto compute = dyn_cast(task.sourceOp)) + return compute.getBody().front(); + return cast(task.sourceOp).getBody().front(); + }; + + auto appendUniqueValue = [](SmallVectorImpl& values, DenseSet& seen, Value value) { + if (seen.insert(value).second) + values.push_back(value); + }; + + DenseMap taskByKey; + DenseMap> tasksByCpu; + SmallVector orderedCpus; + DenseSet seenCpus; + DenseSet internalInputOpsToErase; + DenseMap isInternalInputOpCache; + size_t nextOrder = 0; + auto markCpuSeen = [&](size_t cpu) { + if (seenCpus.insert(cpu).second) + orderedCpus.push_back(cpu); + }; + for (ComputeInstance scheduledInstance : analysisResult.dominanceOrderCompute) { + size_t cpu = analysisResult.computeToCpuMap.at(scheduledInstance); + ScheduledTask task {scheduledInstance, + scheduledInstance.op, + cpu, + analysisResult.computeToCpuSlotMap.lookup(scheduledInstance), + nextOrder++}; + taskByKey[task.key] = task; + tasksByCpu[cpu].push_back(task); + markCpuSeen(cpu); + } + + llvm::sort(orderedCpus); + for (size_t cpu : orderedCpus) { + llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { + if (lhs.slot != rhs.slot) + return lhs.slot < rhs.slot; + return lhs.order < rhs.order; + }); + } + + std::function isInternalInputOp = [&](Operation* op) { + auto it = isInternalInputOpCache.find(op); + if (it != isInternalInputOpCache.end()) + return it->second; + + auto extract = dyn_cast_or_null(op); + if (!extract) + return isInternalInputOpCache[op] = false; + + for (Value result : extract->getResults()) { + for (Operation* user : result.getUsers()) { + if (toEraseSet.contains(user)) + continue; + if (isInternalInputOp(user)) + continue; + return isInternalInputOpCache[op] = false; + } + } + return isInternalInputOpCache[op] = true; + }; + + auto collectInternalInputOps = [&](Value value) { + Operation* op = value.getDefiningOp(); + while (auto extract = dyn_cast_if_present(op)) { + if (isInternalInputOp(extract.getOperation())) + internalInputOpsToErase.insert(extract.getOperation()); + value = extract.getSource(); + op = value.getDefiningOp(); + } + }; + DenseSet externalUsersToMove; auto collectExternalUsers = [&](Operation* op, auto&& collectExternalUsers) -> void { if (!externalUsersToMove.insert(op).second) @@ -1413,28 +1579,294 @@ public: } }; - DenseSet erasedOps; - for (ComputeInstance instance : llvm::reverse(analysisResult.dominanceOrderCompute)) { - if (!erasedOps.insert(instance.op).second) - continue; - Operation* oldOp = instance.op; - if (Operation* newOp = oldToNewOpMap.lookup(oldOp)) { - for (unsigned i = 0; i < oldOp->getNumResults(); ++i) { - for (auto& use : llvm::make_early_inc_range(oldOp->getResult(i).getUses())) { + DenseMap>> remoteSendsByTask; + DenseMap>> remoteInputsByTask; + DenseMap> cpuExternalInputs; + DenseMap> cpuWeights; + DenseMap> cpuExternalOutputs; + DenseMap> seenExternalInputsByCpu; + DenseMap> seenWeightsByCpu; + + for (size_t cpu : orderedCpus) { + for (const ScheduledTask& task : tasksByCpu[cpu]) { + auto taskWeights = getTaskWeights(task); + for (Value weight : taskWeights) + appendUniqueValue(cpuWeights[cpu], seenWeightsByCpu[cpu], weight); + + auto taskInputs = getTaskInputs(task); + auto& remoteInputs = remoteInputsByTask[task.key]; + remoteInputs.resize(taskInputs.size()); + for (auto [inputIndex, input] : 
llvm::enumerate(taskInputs)) { + auto producerRef = getProducerValueRef(input); + if (producerRef) { + collectInternalInputOps(input); + auto producerIt = taskByKey.find(producerRef->instance); + if (producerIt != taskByKey.end()) { + if (producerIt->second.cpu != cpu) { + ChannelInfo info { + nextChannelId++, + getPhysicalCoreId(producerIt->second.cpu), + getPhysicalCoreId(cpu), + }; + remoteInputs[inputIndex] = info; + auto& perResultChannels = remoteSendsByTask[producerRef->instance]; + if (perResultChannels.empty()) + perResultChannels.resize(getTaskOutputTypes(producerIt->second).size()); + perResultChannels[producerRef->resultIndex].push_back(info); + } + continue; + } + } + appendUniqueValue(cpuExternalInputs[cpu], seenExternalInputsByCpu[cpu], input); + } + + auto taskOutputs = getTaskOutputValues(task); + for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { + bool hasExternalUser = false; + for (auto& use : output.getUses()) { Operation* useOwner = use.getOwner(); - if (!toEraseSet.contains(useOwner)) { - use.assign(newOp->getResult(i)); - if (!isa(useOwner) && useOwner->isBeforeInBlock(newOp)) - collectExternalUsers(useOwner, collectExternalUsers); + if (toEraseSet.contains(useOwner)) + continue; + hasExternalUser = true; + if (!isa(useOwner)) + collectExternalUsers(useOwner, collectExternalUsers); + } + if (hasExternalUser) + cpuExternalOutputs[cpu].push_back({task.key, resultIndex}); + } + } + } + + auto returnOp = cast(func.getBody().front().getTerminator()); + IRRewriter rewriter(&getContext()); + DenseMap cpuPrograms; + DenseMap oldToNewExternalValueMap; + + for (size_t cpu : orderedCpus) { + SmallVector operands; + operands.reserve(cpuWeights[cpu].size() + cpuExternalInputs[cpu].size()); + llvm::append_range(operands, cpuWeights[cpu]); + llvm::append_range(operands, cpuExternalInputs[cpu]); + + SmallVector resultTypes; + resultTypes.reserve(cpuExternalOutputs[cpu].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { + ScheduledTask task = taskByKey.at(outputRef.instance); + resultTypes.push_back(getTaskOutputTypes(task)[outputRef.resultIndex]); + } + + rewriter.setInsertionPoint(returnOp); + auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); + newCompute.getProperties().setOperandSegmentSizes( + {static_cast(cpuWeights[cpu].size()), static_cast(cpuExternalInputs[cpu].size())}); + newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(cpu))); + + SmallVector blockArgTypes; + SmallVector blockArgLocs; + blockArgTypes.reserve(cpuExternalInputs[cpu].size()); + blockArgLocs.reserve(cpuExternalInputs[cpu].size()); + for (Value input : cpuExternalInputs[cpu]) { + blockArgTypes.push_back(input.getType()); + blockArgLocs.push_back(loc); + } + Block* newBlock = + rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + + CpuProgram program; + program.op = newCompute; + program.block = newBlock; + for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[cpu])) + program.weightToIndex[weight] = weightIndex; + for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[cpu])) + program.externalInputMap[input] = newBlock->getArgument(inputIndex); + for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[cpu])) { + ScheduledTask task = taskByKey.at(outputRef.instance); + oldToNewExternalValueMap[getTaskOutputValues(task)[outputRef.resultIndex]] = newCompute.getResult(resultIndex); + } + 
cpuPrograms[cpu] = std::move(program); + } + + DenseMap> producedValuesByTask; + for (size_t cpu : orderedCpus) { + CpuProgram& program = cpuPrograms[cpu]; + IRRewriter cpuRewriter(&getContext()); + cpuRewriter.setInsertionPointToEnd(program.block); + + for (const ScheduledTask& task : tasksByCpu[cpu]) { + SmallVector taskInputs = getTaskInputs(task); + auto taskWeights = getTaskWeights(task); + Block& templateBlock = getTaskTemplateBlock(task); + + SmallVector resolvedInputs; + resolvedInputs.reserve(taskInputs.size()); + auto remoteInputsIt = remoteInputsByTask.find(task.key); + for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { + auto producerRef = getProducerValueRef(input); + if (producerRef) { + auto producerIt = taskByKey.find(producerRef->instance); + if (producerIt != taskByKey.end()) { + if (producerIt->second.cpu == cpu) { + auto producedIt = producedValuesByTask.find(producerRef->instance); + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) { + task.sourceOp->emitOpError("missing local producer value during per-cpu merge materialization") + << " consumerCpu=" << cpu << " consumerSlot=" << task.slot + << " producerCpu=" << producerIt->second.cpu << " producerSlot=" << producerIt->second.slot + << " producerLaneStart=" << producerRef->instance.laneStart + << " producerLaneCount=" << producerRef->instance.laneCount; + signalPassFailure(); + return; + } + resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); + continue; + } + const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; + auto receive = + spatial::SpatChannelReceiveOp::create(cpuRewriter, + loc, + input.getType(), + cpuRewriter.getI64IntegerAttr(channelInfo.channelId), + cpuRewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + cpuRewriter.getI32IntegerAttr(channelInfo.targetCoreId)); + resolvedInputs.push_back(receive.getResult()); + continue; + } + } + resolvedInputs.push_back(program.externalInputMap.at(input)); + } + + SmallVector taskYieldValues; + cpuRewriter.setInsertionPointToEnd(program.block); + if (isa(task.sourceOp)) { + IRMapping mapper; + for (auto [argIndex, oldArg] : llvm::enumerate(templateBlock.getArguments())) + mapper.map(oldArg, resolvedInputs[argIndex]); + + for (Operation& op : templateBlock) { + if (auto yield = dyn_cast(&op)) { + for (Value yieldOperand : yield.getOperands()) + taskYieldValues.push_back(mapper.lookup(yieldOperand)); + continue; + } + + Operation* clonedOp = cpuRewriter.clone(op, mapper); + if (auto oldWeightedMvmOp = dyn_cast(&op)) { + auto newWeightedMvmOp = cast(clonedOp); + Value weight = taskWeights[oldWeightedMvmOp.getWeightIndex()]; + newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(weight)); + } + if (auto oldWeightedVmmOp = dyn_cast(&op)) { + auto newWeightedVmmOp = cast(clonedOp); + Value weight = taskWeights[oldWeightedVmmOp.getWeightIndex()]; + newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(weight)); } } } + else { + for (size_t laneOffset = 0; laneOffset < task.key.laneCount; ++laneOffset) { + IRMapping mapper; + if (templateBlock.getNumArguments() == 1) + mapper.map(templateBlock.getArgument(0), resolvedInputs[laneOffset]); + + for (Operation& op : templateBlock) { + if (auto yield = dyn_cast(&op)) { + for (Value yieldOperand : yield.getOperands()) + taskYieldValues.push_back(mapper.lookup(yieldOperand)); + continue; + } + + Operation* clonedOp = cpuRewriter.clone(op, mapper); + if (auto oldWeightedMvmOp = dyn_cast(&op)) { + if 
(oldWeightedMvmOp.getWeightIndex() != 0) { + task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); + signalPassFailure(); + return; + } + auto newWeightedMvmOp = cast(clonedOp); + newWeightedMvmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); + } + if (auto oldWeightedVmmOp = dyn_cast(&op)) { + if (oldWeightedVmmOp.getWeightIndex() != 0) { + task.sourceOp->emitOpError("batched per-cpu merge materialization expects lane-local weight index 0"); + signalPassFailure(); + return; + } + auto newWeightedVmmOp = cast(clonedOp); + newWeightedVmmOp.setWeightIndex(program.weightToIndex.at(taskWeights[laneOffset])); + } + } + } + } + + producedValuesByTask[task.key] = taskYieldValues; + if (auto sendsIt = remoteSendsByTask.find(task.key); sendsIt != remoteSendsByTask.end()) { + for (auto [resultIndex, sendInfos] : llvm::enumerate(sendsIt->second)) { + if (sendInfos.empty()) + continue; + Value producedValue = taskYieldValues[resultIndex]; + for (const ChannelInfo& sendInfo : sendInfos) + spatial::SpatChannelSendOp::create(cpuRewriter, + loc, + cpuRewriter.getI64IntegerAttr(sendInfo.channelId), + cpuRewriter.getI32IntegerAttr(sendInfo.sourceCoreId), + cpuRewriter.getI32IntegerAttr(sendInfo.targetCoreId), + producedValue); + } + } } - oldOp->erase(); + + SmallVector yieldValues; + yieldValues.reserve(cpuExternalOutputs[cpu].size()); + for (ProducerValueRef outputRef : cpuExternalOutputs[cpu]) { + auto producedIt = producedValuesByTask.find(outputRef.instance); + if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) { + ScheduledTask task = taskByKey.at(outputRef.instance); + task.sourceOp->emitOpError("missing yielded external value during per-cpu merge materialization") + << " cpu=" << cpu << " slot=" << task.slot << " laneStart=" << outputRef.instance.laneStart; + signalPassFailure(); + return; + } + yieldValues.push_back(producedIt->second[outputRef.resultIndex]); + } + spatial::SpatYieldOp::create(cpuRewriter, loc, ValueRange(yieldValues)); + } + + for (auto [oldValue, newValue] : oldToNewExternalValueMap) { + for (auto& use : llvm::make_early_inc_range(oldValue.getUses())) + if (!toEraseSet.contains(use.getOwner())) + use.assign(newValue); + } + + DenseSet allOpsToErase = toEraseSet; + for (Operation* op : internalInputOpsToErase) + allOpsToErase.insert(op); + + SmallVector orderedOpsToErase; + for (Operation& op : func.getBody().front()) + if (allOpsToErase.contains(&op)) + orderedOpsToErase.push_back(&op); + for (Operation* op : llvm::reverse(orderedOpsToErase)) { + SmallVector remainingUsers; + for (Value result : op->getResults()) + for (Operation* user : result.getUsers()) + remainingUsers.push_back(user); + if (!remainingUsers.empty()) { + llvm::errs() << "[MergeComputeNodesPass] refusing to erase op with remaining uses: " << op->getName() << "\n"; + llvm::errs() << " erase-set: " << (allOpsToErase.contains(op) ? "yes" : "no") << "\n"; + op->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions()); + llvm::errs() << "\n"; + for (Operation* user : remainingUsers) { + llvm::errs() << " user: " << user->getName() + << " erase-set=" << (allOpsToErase.contains(user) ? 
"yes" : "no") << "\n"; + user->print(llvm::errs(), mlir::OpPrintingFlags().skipRegions()); + llvm::errs() << "\n"; + } + op->emitOpError("still has uses during per-cpu merge cleanup"); + signalPassFailure(); + return; + } + op->erase(); } - func::FuncOp func = getOperation(); - auto returnOp = cast(func.getBody().front().getTerminator()); SmallVector orderedUsersToMove; for (Operation& op : func.getBody().front()) { if (&op == returnOp.getOperation()) @@ -1445,9 +1877,13 @@ public: for (Operation* op : orderedUsersToMove) op->moveBefore(returnOp); - sinkChannelsIntoComputes(func, nextChannelId); rebatchEquivalentComputes(func, nextChannelId); compactScalarChannelRuns(func, nextChannelId); + if (!sortTopologically(&func.getBody().front())) { + func.emitOpError("failed to topologically order merged Spatial IR"); + signalPassFailure(); + return; + } dumpModule(cast(func->getParentOp()), "spatial1_dcp_merged"); generateReport(func, "spatial1_dcp_merged_report", analysisResult.cpuToLastComputeMap.size()); } @@ -1477,16 +1913,18 @@ private: Value resolvedInput = input; if (auto producerRef = getProducerValueRef(input)) { LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); - auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); - (void) isChannel; - (void) channelVal; - resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, - loc, - input.getType(), - rewriter.getI64IntegerAttr(channelInfo.channelId), - rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) - .getResult(); + auto [channelVal, isChannel, channelInfo] = + producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); + if (isChannel) + resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + input.getType(), + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) + .getResult(); + else + resolvedInput = channelVal; } newComputeOperands.push_back(resolvedInput); @@ -1532,7 +1970,8 @@ private: uint32_t firstLane, uint32_t laneCount, size_t currentCpu, - const DCPAnalysisResult& analysisResult) { + const DCPAnalysisResult& analysisResult, + std::optional rebatchPhase = std::nullopt) { func::FuncOp func = getOperation(); auto loc = func.getLoc(); IRRewriter rewriter(&getContext()); @@ -1547,24 +1986,29 @@ private: for (uint32_t lane = firstLane; lane < firstLane + laneCount; ++lane) { weights.push_back(batch.getWeights()[lane]); - resultTypes.push_back(batch.getOutputs()[lane].getType()); + if (!batch.getOutputs().empty()) + resultTypes.push_back(batch.getOutputs()[lane].getType()); - Value input = batch.getInputs()[lane]; - Value resolvedInput = input; - if (auto producerRef = getProducerValueRef(input)) { - LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); - auto [channelVal, isChannel, channelInfo] = producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); - (void) isChannel; - (void) channelVal; - resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, - loc, - input.getType(), - rewriter.getI64IntegerAttr(channelInfo.channelId), - rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), - rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) - .getResult(); + if (!batch.getInputs().empty()) { + Value input = batch.getInputs()[lane]; + Value resolvedInput = input; 
+ if (auto producerRef = getProducerValueRef(input)) { + LazyInsertComputeResult& producer = newComputeNodeResults.at(producerRef->instance); + auto [channelVal, isChannel, channelInfo] = + producer.getAsChannelValueAndInsertSender(producerRef->resultIndex, currentCpu); + if (isChannel) + resolvedInput = spatial::SpatChannelReceiveOp::create(rewriter, + loc, + input.getType(), + rewriter.getI64IntegerAttr(channelInfo.channelId), + rewriter.getI32IntegerAttr(channelInfo.sourceCoreId), + rewriter.getI32IntegerAttr(channelInfo.targetCoreId)) + .getResult(); + else + resolvedInput = channelVal; + } + inputs.push_back(resolvedInput); } - inputs.push_back(resolvedInput); } Block& templateBlock = batch.getBody().front(); @@ -1574,11 +2018,17 @@ private: compute.getProperties().setOperandSegmentSizes( {static_cast(weights.size()), static_cast(inputs.size())}); compute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(getPhysicalCoreId(currentCpu))); + if (rebatchPhase) + compute->setAttr(kRebatchPhaseAttrName, rewriter.getI64IntegerAttr(*rebatchPhase)); - auto* newBlock = rewriter.createBlock( - &compute.getBody(), compute.getBody().end(), TypeRange {templateBlock.getArgument(0).getType()}, {loc}); + SmallVector blockArgTypes; + if (templateBlock.getNumArguments() == 1) + blockArgTypes.push_back(templateBlock.getArgument(0).getType()); + SmallVector blockArgLocs(templateBlock.getNumArguments(), loc); + auto* newBlock = rewriter.createBlock(&compute.getBody(), compute.getBody().end(), blockArgTypes, blockArgLocs); IRMapping mapper; - mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); + if (templateBlock.getNumArguments() == 1) + mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : templateBlock) rewriter.clone(op, mapper); @@ -1600,14 +2050,16 @@ private: ValueRange(weights), ValueRange(inputs)); rebatched->setAttr(onnx_mlir::kCoreIdAttrName, - rewriter.getDenseI32ArrayAttr(SmallVector(laneCount, getPhysicalCoreId(currentCpu)))); + rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount))); - auto* newBlock = rewriter.createBlock(&rebatched.getBody(), - rebatched.getBody().end(), - TypeRange {templateBlock.getArgument(0).getType()}, - SmallVector(1, loc)); + SmallVector blockArgTypes; + if (templateBlock.getNumArguments() == 1) + blockArgTypes.push_back(templateBlock.getArgument(0).getType()); + SmallVector blockArgLocs(templateBlock.getNumArguments(), loc); + auto* newBlock = rewriter.createBlock(&rebatched.getBody(), rebatched.getBody().end(), blockArgTypes, blockArgLocs); IRMapping mapper; - mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); + if (templateBlock.getNumArguments() == 1) + mapper.map(templateBlock.getArgument(0), newBlock->getArgument(0)); rewriter.setInsertionPointToEnd(newBlock); for (Operation& op : templateBlock) { if (auto yield = dyn_cast(&op)) { @@ -1621,7 +2073,7 @@ private: ComputeValueResults results; results.outerValues.assign(rebatched->result_begin(), rebatched->result_end()); results.innerValues = results.outerValues; - if (results.innerValues.empty()) + if (results.innerValues.empty() && yieldOp.getNumOperands() == 1) results.innerValues.push_back(yieldOp.getOperand(0)); newComputeNodeResults.insert({ ComputeInstance {batch.getOperation(), firstLane, laneCount}, @@ -1658,7 +2110,7 @@ private: channelInfo, insertVal}; return ret; }; - return LazyInsertComputeResult(computeValueResults, insertNew); + return 
LazyInsertComputeResult(computeValueResults, producerCpu, insertNew); } }; diff --git a/validation/raptor.py b/validation/raptor.py index b9c147b..506cd1f 100644 --- a/validation/raptor.py +++ b/validation/raptor.py @@ -1,4 +1,5 @@ import re +import shlex import subprocess from pathlib import Path from colorama import Fore, Style @@ -37,8 +38,12 @@ def _parse_pim_pass_timings(output_text): return pass_timings +def _format_command(cmd): + return shlex.join(str(arg) for arg in cmd) + + def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, - crossbar_size, crossbar_count, cwd=None, reporter=None): + crossbar_size, crossbar_count, core_count=None, cwd=None, reporter=None): # Define the arguments, with the possibility to set crossbar size and count args = [ network_path, @@ -51,10 +56,18 @@ def compile_with_raptor(network_path, raptor_onnx_path: Path, output_base: Path, f"--crossbar-count={crossbar_count}", "--enable-timing", ] + if core_count is not None: + args.append(f"--core-count={core_count}") + + cmd = [str(raptor_onnx_path)] + [str(arg) for arg in args] + if reporter is not None: + reporter.log(f" Raptor command: {_format_command(cmd)}") + else: + print(f"Raptor command: {_format_command(cmd)}") try: output_text = run_command_with_reporter( - [str(raptor_onnx_path)] + [str(arg) for arg in args], + cmd, cwd=cwd, reporter=reporter, capture_output=True, diff --git a/validation/validate.py b/validation/validate.py index 357f6fa..e3c8d6c 100644 --- a/validation/validate.py +++ b/validation/validate.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import argparse -import shlex import signal import subprocess import sys @@ -11,12 +10,6 @@ from validate_one import ProgressReporter, clean_workspace_artifacts, validate_n from raptor import PIM_PASS_LABELS -def format_command(cmd): - if isinstance(cmd, (list, tuple)): - return shlex.join(str(arg) for arg in cmd) - return str(cmd) - - def format_return_status(returncode): if returncode < 0: signal_num = -returncode @@ -34,8 +27,6 @@ def print_validation_error(reporter, rel, exc): file=sys.stderr, flush=True) if isinstance(exc, subprocess.CalledProcessError): print(format_return_status(exc.returncode), file=sys.stderr, flush=True) - print("Retry command:", file=sys.stderr, flush=True) - print(format_command(exc.cmd), file=sys.stderr, flush=True) else: print(f"{type(exc).__name__}: {exc}", file=sys.stderr, flush=True) print("=" * 72, file=sys.stderr, flush=True) @@ -65,6 +56,8 @@ def main(): ap.add_argument("--threshold", type=float, default=1e-3, help="Max allowed diff per output element.") ap.add_argument("--crossbar-size", type=int, default=64) ap.add_argument("--crossbar-count", type=int, default=8) + ap.add_argument("--core-count", type=int, default=None, + help="Core count to pass to Raptor. 
If omitted, Raptor uses its default.") ap.add_argument("--clean", action="store_true", help="Remove generated validation artifacts under each model workspace and exit.") a = ap.parse_args() @@ -114,7 +107,7 @@ def main(): try: result = validate_network( onnx_path, a.raptor_path, a.onnx_include_dir, simulator_dir, - crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, + crossbar_size=a.crossbar_size, crossbar_count=a.crossbar_count, core_count=a.core_count, threshold=a.threshold, reporter=reporter, model_index=index, diff --git a/validation/validate_one.py b/validation/validate_one.py index 4cc4ea9..cd5c1e7 100644 --- a/validation/validate_one.py +++ b/validation/validate_one.py @@ -1,4 +1,3 @@ -import argparse import json import numpy as np import subprocess @@ -258,7 +257,7 @@ def validate_outputs(sim_arrays, runner_out_dir, outputs_descriptor, threshold=1 def validate_network(network_onnx_path, raptor_path, onnx_include_dir, - simulator_dir, crossbar_size=64, crossbar_count=8, threshold=1e-3, + simulator_dir, crossbar_size=64, crossbar_count=8, core_count=None, threshold=1e-3, reporter=None, model_index=1, model_total=1): network_onnx_path = Path(network_onnx_path).resolve() raptor_path = Path(raptor_path).resolve() @@ -313,7 +312,7 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, print_stage(reporter, model_index, model_total, network_onnx_path.name, "Compile PIM") pim_pass_timings = compile_with_raptor( network_mlir_path, raptor_path, raptor_dir / network_onnx_path.stem, - crossbar_size, crossbar_count, + crossbar_size, crossbar_count, core_count=core_count, cwd=raptor_dir, reporter=reporter) print_info(reporter, f"PIM artifacts saved to {raptor_dir / 'pim'}") reporter.advance() @@ -350,18 +349,3 @@ def validate_network(network_onnx_path, raptor_path, onnx_include_dir, reporter.log("=" * 72) if owns_reporter: reporter.finish() - - -if __name__ == '__main__': - ap = argparse.ArgumentParser() - ap.add_argument("--network-onnx", required=True) - ap.add_argument("--raptor-path", required=True) - ap.add_argument("--onnx-include-dir", required=True) - a = ap.parse_args() - - simulator_dir = Path(__file__).parent.resolve() / ".." 
/ "backend-simulators" / "pim" / "pim-simulator" - - passed = validate_network( - a.network_onnx, a.raptor_path, a.onnx_include_dir, simulator_dir - ) - raise SystemExit(0 if passed.passed else 1) From b2dc9c38b61aede4bc9c04e7325b5ce3ac69e524 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Wed, 6 May 2026 12:21:58 +0200 Subject: [PATCH 2/2] better spatial IR compaction with better custom syntax, scf.for and spat.map --- src/PIM/Common/PimCommon.hpp | 3 +- src/PIM/Compiler/PimCodeGen.cpp | 4 +- .../SpatialToPim/SpatialToPimPass.cpp | 128 +++- src/PIM/Dialect/Pim/Pim.td | 2 +- .../OpBufferizationInterfaces.cpp | 4 +- src/PIM/Dialect/Spatial/CMakeLists.txt | 1 + src/PIM/Dialect/Spatial/Spatial.td | 50 +- src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp | 704 +++++++++++++++--- src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp | 80 +- .../MergeComputeNodesPass.cpp | 149 +--- .../MergeComputeNodes/RegularOpCompaction.cpp | 577 ++++++++++++++ .../MergeComputeNodes/RegularOpCompaction.hpp | 14 + 12 files changed, 1442 insertions(+), 274 deletions(-) create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp create mode 100644 src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp diff --git a/src/PIM/Common/PimCommon.hpp b/src/PIM/Common/PimCommon.hpp index 59c5ee1..6880012 100644 --- a/src/PIM/Common/PimCommon.hpp +++ b/src/PIM/Common/PimCommon.hpp @@ -22,6 +22,7 @@ namespace onnx_mlir { -inline constexpr llvm::StringLiteral kCoreIdAttrName = "core_id"; +inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId"; +inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds"; } // namespace onnx_mlir diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index c6606b9..7a40d44 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -517,8 +517,8 @@ static SmallVector getUsedWeightIndices(pim::PimCoreOp coreOp) { } static SmallVector getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { - auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdAttrName); - assert(coreIdsAttr && "pim.core_batch requires core_id array attribute"); + auto coreIdsAttr = coreBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); + assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute"); return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 5ec197b..e0f44c8 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -111,7 +111,7 @@ static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& } static SmallVector getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { - if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + if (auto coreIdsAttr = computeBatchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); SmallVector coreIds; @@ -178,6 +178,43 @@ static void lowerChannelReceiveMany(spatial::SpatChannelReceiveManyOp receiveMan rewriter.replaceOp(receiveManyOp, ValueRange(replacements)); } +static void lowerChannelSendManyBatch(spatial::SpatChannelSendManyBatchOp sendManyBatchOp, + int32_t laneCount, + IRMapping& mapper, + IRRewriter& rewriter) { + auto targetCoreIds = sendManyBatchOp.getTargetCoreIds(); 
+ for (auto [valueIndex, input] : llvm::enumerate(sendManyBatchOp.getInputs())) { + size_t metadataOffset = valueIndex * static_cast(laneCount); + auto targetSlice = targetCoreIds.slice(metadataOffset, laneCount); + pim::PimSendBatchOp::create(rewriter, + sendManyBatchOp.getLoc(), + mapper.lookup(input), + getTensorSizeInBytesAttr(rewriter, mapper.lookup(input)), + rewriter.getDenseI32ArrayAttr(targetSlice)); + } +} + +static void lowerChannelReceiveManyBatch(spatial::SpatChannelReceiveManyBatchOp receiveManyBatchOp, + int32_t laneCount, + IRMapping& mapper, + IRRewriter& rewriter) { + auto sourceCoreIds = receiveManyBatchOp.getSourceCoreIds(); + for (auto [valueIndex, output] : llvm::enumerate(receiveManyBatchOp.getOutputs())) { + size_t metadataOffset = valueIndex * static_cast(laneCount); + auto sourceSlice = sourceCoreIds.slice(metadataOffset, laneCount); + auto outputType = cast(output.getType()); + auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveManyBatchOp.getLoc(), outputType); + auto received = pim::PimReceiveBatchOp::create(rewriter, + receiveManyBatchOp.getLoc(), + outputBuffer.getType(), + outputBuffer, + getTensorSizeInBytesAttr(rewriter, output), + rewriter.getDenseI32ArrayAttr(sourceSlice)) + .getOutput(); + mapper.map(output, received); + } +} + static void lowerExtractRows(spatial::SpatExtractRowsOp extractRowsOp, IRRewriter& rewriter) { Value input = extractRowsOp.getInput(); RankedTensorType inputType; @@ -226,6 +263,56 @@ static void lowerConcat(spatial::SpatConcatOp concatOp, IRRewriter& rewriter) { rewriter.replaceOp(concatOp, concatenated); } +static void lowerRemainingSpatialMathOps(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector wvmmOps; + funcOp.walk([&](spatial::SpatWeightedVMMOp wvmmOp) { + if (wvmmOp->getParentOfType() || wvmmOp->getParentOfType()) + wvmmOps.push_back(wvmmOp); + }); + + for (auto wvmmOp : wvmmOps) { + rewriter.setInsertionPoint(wvmmOp); + auto outputType = cast(wvmmOp.getOutput().getType()); + Value outputBuffer = createEmptyTensorFromShaped(rewriter, wvmmOp.getLoc(), outputType).getResult(); + rewriter.replaceOpWithNewOp(wvmmOp, + wvmmOp.getOutput().getType(), + rewriter.getI32IntegerAttr(wvmmOp.getWeightIndex()), + wvmmOp.getInput(), + outputBuffer); + } +} + +static void expandMapOps(func::FuncOp funcOp, IRRewriter& rewriter) { + SmallVector mapOps; + funcOp.walk([&](spatial::SpatMapOp mapOp) { mapOps.push_back(mapOp); }); + + for (auto mapOp : mapOps) { + Block& body = mapOp.getBody().front(); + auto yieldOp = cast(body.getTerminator()); + + SmallVector replacements; + replacements.reserve(mapOp.getInputs().size()); + rewriter.setInsertionPoint(mapOp); + for (Value input : mapOp.getInputs()) { + IRMapping mapping; + mapping.map(body.getArgument(0), input); + + Value replacement = input; + for (Operation& op : body.without_terminator()) { + Operation* cloned = rewriter.clone(op, mapping); + for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapping.map(originalResult, clonedResult); + rewriter.setInsertionPointAfter(cloned); + } + + replacement = mapping.lookupOrDefault(yieldOp.getOperand(0)); + replacements.push_back(replacement); + } + + rewriter.replaceOp(mapOp, replacements); + } +} + static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp, SmallVectorImpl& helperChain, bool requireReturnUse = true) { @@ -551,6 +638,7 @@ void SpatialToPimPass::runOnOperation() { func::FuncOp funcOp = *entryFunc; IRRewriter rewriter(&getContext()); + 
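+  // Expand spat.map ops into per-input clones of their body up front, before the
+  // Spatial-to-PIM conversion patterns run over the function.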
expandMapOps(funcOp, rewriter); ConversionTarget target(*ctx); target.addLegalDialect coreOps; + funcOp.walk([&](pim::PimCoreOp coreOp) { coreOps.push_back(coreOp); }); + for (auto coreOp : coreOps) { + if (failed(applyPatternsGreedily(coreOp.getOperation(), frozenCoreBodyPatterns))) { + signalPassFailure(); + return; + } + } + + SmallVector coreBatchOps; + funcOp.walk([&](pim::PimCoreBatchOp coreBatchOp) { coreBatchOps.push_back(coreBatchOp); }); + for (auto coreBatchOp : coreBatchOps) { + if (failed(applyPatternsGreedily(coreBatchOp.getOperation(), frozenCoreBodyPatterns))) { + signalPassFailure(); + return; + } + } + } + + lowerRemainingSpatialMathOps(funcOp, rewriter); + RewritePatternSet channelPatterns(ctx); populateWithGenerated(channelPatterns); if (failed(applyPatternsGreedily(funcOp, std::move(channelPatterns)))) { @@ -939,7 +1053,7 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc ValueRange(batchInputs)); coreBatchOp.getProperties().setOperandSegmentSizes( {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); - coreBatchOp->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); SmallVector blockArgTypes; SmallVector blockArgLocs; @@ -1000,6 +1114,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc continue; } + if (auto sendManyBatchOp = dyn_cast(op)) { + lowerChannelSendManyBatch(sendManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter); + continue; + } + if (auto receiveBatchOp = dyn_cast(op)) { auto outputType = cast(receiveBatchOp.getOutput().getType()); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); @@ -1014,6 +1133,11 @@ void SpatialToPimPass::runOnComputeBatchOp(spatial::SpatComputeBatch computeBatc continue; } + if (auto receiveManyBatchOp = dyn_cast(op)) { + lowerChannelReceiveManyBatch(receiveManyBatchOp, computeBatchOp.getLaneCount(), mapper, rewriter); + continue; + } + if (auto toTensorOp = dyn_cast(op)) { if (isa_and_present(toTensorOp.getBuffer().getDefiningOp())) { Operation* cloned = rewriter.clone(op, mapper); diff --git a/src/PIM/Dialect/Pim/Pim.td b/src/PIM/Dialect/Pim/Pim.td index 7a2b7d8..3ecc353 100644 --- a/src/PIM/Dialect/Pim/Pim.td +++ b/src/PIM/Dialect/Pim/Pim.td @@ -39,7 +39,7 @@ def PimCoreOp : PimOp<"core", [SingleBlock, IsolatedFromAbove]> { }]; } -def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, AttrSizedOperandSegments]> { +def PimCoreBatchOp : PimOp<"core_batch", [SingleBlock, IsolatedFromAbove, AttrSizedOperandSegments]> { let summary = "Execute equivalent batched core bodies"; let regions = (region SizedRegion<1>:$body); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp index c9a6769..481d0af 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/OpBufferizationInterfaces.cpp @@ -257,8 +257,8 @@ struct CoreBatchOpInterface : BufferizableOpInterface::ExternalModel(weights.size()), static_cast(inputs.size())}); - if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdAttrName)) - newOp->setAttr(onnx_mlir::kCoreIdAttrName, coreIdsAttr); + if (auto coreIdsAttr = coreBatchOp->getAttr(onnx_mlir::kCoreIdsAttrName)) + newOp->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr); 
rewriter.inlineRegionBefore(coreBatchOp.getBody(), newOp.getBody(), newOp.getBody().begin()); for (Block& block : newOp.getBody()) diff --git a/src/PIM/Dialect/Spatial/CMakeLists.txt b/src/PIM/Dialect/Spatial/CMakeLists.txt index 50f5ce0..6f6a7ff 100644 --- a/src/PIM/Dialect/Spatial/CMakeLists.txt +++ b/src/PIM/Dialect/Spatial/CMakeLists.txt @@ -8,6 +8,7 @@ add_pim_library(SpatialOps SpatialOpsVerify.cpp SpatialOpsCanonicalization.cpp Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp + Transforms/MergeComputeNodes/RegularOpCompaction.cpp Transforms/MergeComputeNodes/DCPGraph/Graph.cpp Transforms/MergeComputeNodes/DCPGraph/GraphDebug.cpp Transforms/MergeComputeNodes/DCPGraph/GraphSupport.cpp diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 238472f..9a074a6 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -102,6 +102,23 @@ def SpatConcatOp : SpatOp<"concat", []> { let hasCustomAssemblyFormat = 1; } +def SpatMapOp : SpatOp<"map", [SingleBlock]> { + let summary = "Apply the same lane-local region to many independent tensors"; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + + let regions = (region SizedRegion<1>:$body); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Communication //===----------------------------------------------------------------------===// @@ -184,6 +201,20 @@ def SpatChannelSendBatchOp : SpatOp<"channel_send_batch", []> { let hasCustomAssemblyFormat = 1; } +def SpatChannelSendManyBatchOp : SpatOp<"channel_send_many_batch", []> { + let summary = "Send multiple per-lane tensors through logical channels in a batch body"; + + let arguments = (ins + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds, + Variadic:$inputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { let summary = "Receive a per-lane tensor through logical channels in a batch body"; @@ -201,11 +232,28 @@ def SpatChannelReceiveBatchOp : SpatOp<"channel_receive_batch", []> { let hasCustomAssemblyFormat = 1; } +def SpatChannelReceiveManyBatchOp : SpatOp<"channel_receive_many_batch", []> { + let summary = "Receive multiple per-lane tensors through logical channels in a batch body"; + + let arguments = (ins + DenseI64ArrayAttr:$channelIds, + DenseI32ArrayAttr:$sourceCoreIds, + DenseI32ArrayAttr:$targetCoreIds + ); + + let results = (outs + Variadic:$outputs + ); + + let hasVerifier = 1; + let hasCustomAssemblyFormat = 1; +} + //===----------------------------------------------------------------------===// // Math //===----------------------------------------------------------------------===// -def SpatWeightedVMMOp : SpatOp<"Wvmm", []> { +def SpatWeightedVMMOp : SpatOp<"wvmm", []> { let summary = "Vector-matrix multiplication within a weighted compute operation"; let arguments = (ins diff --git a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp index 022b623..d9ac3c5 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsAsm.cpp @@ -42,6 +42,10 @@ static void printCloseDelimiter(OpAsmPrinter& printer, ListDelimiter delimiter) printer << (delimiter == ListDelimiter::Square ? 
"]" : ")"); } +static bool parseOptionalKeywordAlias(OpAsmParser& parser, StringRef preferred, StringRef legacy) { + return succeeded(parser.parseOptionalKeyword(preferred)) || succeeded(parser.parseOptionalKeyword(legacy)); +} + template static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, ListDelimiter delimiter, @@ -75,51 +79,65 @@ static ParseResult parseCompressedRepeatedList(OpAsmParser& parser, } template -static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { - if (parser.parseLSquare()) - return failure(); - if (succeeded(parser.parseOptionalRSquare())) +static ParseResult +parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) return success(); while (true) { - int64_t first = 0; - if (parser.parseInteger(first)) - return failure(); + if (succeeded(parser.parseOptionalLParen())) { + SmallVector subgroup; + if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup)) + return failure(); - if (succeeded(parser.parseOptionalKeyword("to"))) { - int64_t last = 0; - if (parser.parseInteger(last) || last < first) - return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); - - int64_t step = 1; - if (succeeded(parser.parseOptionalKeyword("by"))) { - if (parser.parseInteger(step) || step <= 0) - return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); - } int64_t repeatCount = 1; if (succeeded(parser.parseOptionalKeyword("x"))) { if (parser.parseInteger(repeatCount) || repeatCount <= 0) return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); } - if ((last - first) % step != 0) - return parser.emitError(parser.getCurrentLocation(), - "range end must be reachable from start using the given step"); - - for (int64_t value = first; value <= last; value += step) - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(value)); + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(values, subgroup); } else { - int64_t repeatCount = 1; - if (succeeded(parser.parseOptionalKeyword("x"))) { - if (parser.parseInteger(repeatCount) || repeatCount <= 0) - return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + int64_t first = 0; + if (parser.parseInteger(first)) + return failure(); + + if (succeeded(parser.parseOptionalKeyword("to"))) { + int64_t last = 0; + if (parser.parseInteger(last) || last < first) + return parser.emitError(parser.getCurrentLocation(), "invalid ascending range"); + + int64_t step = 1; + if (succeeded(parser.parseOptionalKeyword("by"))) { + if (parser.parseInteger(step) || step <= 0) + return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive"); + } + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + if ((last - first) % step != 0) + return parser.emitError(parser.getCurrentLocation(), + "range end must be reachable from start using the given step"); + + for (int64_t value = first; value <= last; value += step) + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(value)); + } + else { + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) 
{ + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t index = 0; index < repeatCount; ++index) + values.push_back(static_cast(first)); } - for (int64_t index = 0; index < repeatCount; ++index) - values.push_back(static_cast(first)); } - if (succeeded(parser.parseOptionalRSquare())) + if (succeeded(parseOptionalCloseDelimiter(parser, delimiter))) break; if (parser.parseComma()) return failure(); @@ -128,6 +146,14 @@ static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorIm return success(); } +template +static ParseResult +parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl& values) { + if (parseOpenDelimiter(parser, delimiter)) + return failure(); + return parseCompressedIntegerEntries(parser, delimiter, values); +} + template static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) { for (size_t index = 0; index < entries.size();) { @@ -146,35 +172,51 @@ static void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, Prin } template -static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { - printer << "["; - for (size_t index = 0; index < values.size();) { - if (index != 0) - printer << ", "; - - auto findEqualRunEnd = [&](size_t start) { - size_t end = start + 1; - while (end < values.size() && values[end] == values[start]) - ++end; - return end; +static void printCompressedIntegerSequence(OpAsmPrinter& printer, ArrayRef values, ListDelimiter delimiter) { + struct FlatCompression { + enum class Kind { + Single, + EqualRun, + Progression }; - size_t firstRunEnd = findEqualRunEnd(index); - size_t repeatCount = firstRunEnd - index; + Kind kind = Kind::Single; + size_t covered = 1; + size_t repeatCount = 1; + size_t progressionValueCount = 1; + int64_t step = 1; + IntT firstValue {}; + IntT lastValue {}; + }; + + auto computeFlatCompression = [&](size_t start) { + FlatCompression compression; + compression.firstValue = values[start]; + compression.lastValue = values[start]; + + auto findEqualRunEnd = [&](size_t runStart) { + size_t runEnd = runStart + 1; + while (runEnd < values.size() && values[runEnd] == values[runStart]) + ++runEnd; + return runEnd; + }; + + size_t firstRunEnd = findEqualRunEnd(start); + compression.repeatCount = firstRunEnd - start; size_t progressionEnd = firstRunEnd; int64_t step = 0; - IntT lastValue = values[index]; + IntT lastValue = values[start]; if (firstRunEnd < values.size()) { size_t secondRunEnd = findEqualRunEnd(firstRunEnd); - step = static_cast(values[firstRunEnd]) - static_cast(values[index]); - if (step > 0 && secondRunEnd - firstRunEnd == repeatCount) { + step = static_cast(values[firstRunEnd]) - static_cast(values[start]); + if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) { progressionEnd = secondRunEnd; lastValue = values[firstRunEnd]; size_t currentRunStart = secondRunEnd; while (currentRunStart < values.size()) { size_t currentRunEnd = findEqualRunEnd(currentRunStart); - if (currentRunEnd - currentRunStart != repeatCount) + if (currentRunEnd - currentRunStart != compression.repeatCount) break; if (static_cast(values[currentRunStart]) != static_cast(lastValue) + step) break; @@ -188,27 +230,99 @@ static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef val } } - size_t progressionValueCount = repeatCount == 0 ? 
0 : (progressionEnd - index) / repeatCount; - if (progressionEnd > firstRunEnd && progressionValueCount >= 3) { - printer << values[index] << " to " << lastValue; - if (step != 1) - printer << " by " << step; - if (repeatCount > 1) - printer << " x" << repeatCount; - index = progressionEnd; - continue; + compression.covered = 1; + if (progressionEnd > firstRunEnd) { + size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount; + if (progressionValueCount >= 3) { + compression.kind = FlatCompression::Kind::Progression; + compression.covered = progressionEnd - start; + compression.progressionValueCount = progressionValueCount; + compression.step = step; + compression.lastValue = lastValue; + return compression; + } } - if (repeatCount > 1) { - printer << values[index] << " x" << repeatCount; - index = firstRunEnd; - continue; + if (compression.repeatCount > 1) { + compression.kind = FlatCompression::Kind::EqualRun; + compression.covered = compression.repeatCount; + return compression; } - printer << values[index]; - index = firstRunEnd; + return compression; + }; + + auto findRepeatedSublist = [&](size_t start) { + size_t bestLength = 0; + size_t bestRepeatCount = 1; + size_t remaining = values.size() - start; + + for (size_t length = 2; length * 2 <= remaining; ++length) { + size_t repeatCount = 1; + ArrayRef candidate = values.slice(start, length); + while (start + (repeatCount + 1) * length <= values.size() + && llvm::equal(candidate, values.slice(start + repeatCount * length, length))) { + ++repeatCount; + } + + if (repeatCount <= 1) + continue; + + size_t covered = length * repeatCount; + size_t bestCovered = bestLength * bestRepeatCount; + if (covered > bestCovered || (covered == bestCovered && length < bestLength)) { + bestLength = length; + bestRepeatCount = repeatCount; + } + } + + return std::pair(bestLength, bestRepeatCount); + }; + + printOpenDelimiter(printer, delimiter); + for (size_t index = 0; index < values.size();) { + if (index != 0) + printer << ", "; + + FlatCompression flat = computeFlatCompression(index); + auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index); + size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount; + if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) { + printCompressedIntegerSequence(printer, values.slice(index, sublistLength), ListDelimiter::Paren); + printer << " x" << sublistRepeatCount; + index += repeatedSublistCoverage; + continue; + } + switch (flat.kind) { + case FlatCompression::Kind::Progression: + printer << flat.firstValue << " to " << flat.lastValue; + if (flat.step != 1) + printer << " by " << flat.step; + if (flat.repeatCount > 1) + printer << " x" << flat.repeatCount; + index += flat.covered; + break; + case FlatCompression::Kind::EqualRun: + printer << flat.firstValue << " x" << flat.repeatCount; + index += flat.covered; + break; + case FlatCompression::Kind::Single: + printer << flat.firstValue; + index += flat.covered; + break; + } } - printer << "]"; + printCloseDelimiter(printer, delimiter); +} + +template +static ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl& values) { + return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values); +} + +template +static void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef values) { + printCompressedIntegerSequence(printer, values, ListDelimiter::Square); } static void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, 
ListDelimiter delimiter) { @@ -267,6 +381,165 @@ static void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, List printCloseDelimiter(printer, delimiter); } +static ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser, + SmallVectorImpl& operands); +static ParseResult parseCompressedOperandSequence(OpAsmParser& parser, + SmallVectorImpl& operands); +static ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl& types, bool allowEmpty); + +static bool hasRepeatedTuple(ValueRange values, size_t tupleSize) { + if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0) + return false; + + SmallVector valueVec(values.begin(), values.end()); + ArrayRef tuple(valueVec.data(), tupleSize); + for (size_t index = tupleSize; index < values.size(); index += tupleSize) + if (!llvm::equal(tuple, ArrayRef(valueVec).slice(index, tupleSize))) + return false; + return true; +} + +static bool hasRepeatedTuple(TypeRange types, size_t tupleSize) { + if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0) + return false; + + SmallVector typeVec(types.begin(), types.end()); + ArrayRef tuple(typeVec.data(), tupleSize); + for (size_t index = tupleSize; index < types.size(); index += tupleSize) + if (!llvm::equal(tuple, ArrayRef(typeVec).slice(index, tupleSize))) + return false; + return true; +} + +static void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize) { + printer << "["; + printOpenDelimiter(printer, ListDelimiter::Paren); + for (size_t index = 0; index < tupleSize; ++index) { + if (index != 0) + printer << ", "; + printer.printOperand(values[index]); + } + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " x" << (values.size() / tupleSize) << "]"; +} + +static void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize) { + printer << "["; + printOpenDelimiter(printer, ListDelimiter::Paren); + for (size_t index = 0; index < tupleSize; ++index) { + if (index != 0) + printer << ", "; + printer.printType(types[index]); + } + printCloseDelimiter(printer, ListDelimiter::Paren); + printer << " x" << (types.size() / tupleSize) << "]"; +} + +static ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser, + SmallVectorImpl& operands) { + if (parser.parseLSquare()) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + if (succeeded(parser.parseOptionalLParen())) { + SmallVector tupleOperands; + if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(operands, tupleOperands); + + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseLParen()) + return failure(); + tupleOperands.clear(); + if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen()) + return failure(); + + repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(operands, tupleOperands); + } + return 
parser.parseRSquare(); + } + + while (true) { + if (parseOneCompressedOperandEntry(parser, operands)) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + if (parser.parseComma()) + return failure(); + } +} + +static ParseResult parseCompressedOrTupleTypeList(OpAsmParser& parser, SmallVectorImpl& types) { + if (parser.parseLSquare()) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + if (succeeded(parser.parseOptionalLParen())) { + SmallVector tupleTypes; + if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(types, tupleTypes); + + while (succeeded(parser.parseOptionalComma())) { + if (parser.parseLParen()) + return failure(); + tupleTypes.clear(); + if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen()) + return failure(); + + repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + llvm::append_range(types, tupleTypes); + } + return parser.parseRSquare(); + } + + while (true) { + Type type; + if (parser.parseType(type)) + return failure(); + + int64_t repeatCount = 1; + if (succeeded(parser.parseOptionalKeyword("x"))) { + if (parser.parseInteger(repeatCount) || repeatCount <= 0) + return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive"); + } + for (int64_t repeat = 0; repeat < repeatCount; ++repeat) + types.push_back(type); + + if (succeeded(parser.parseOptionalRSquare())) + return success(); + if (parser.parseComma()) + return failure(); + } +} + static ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser, OpAsmParser::UnresolvedOperand firstOperand, SmallVectorImpl& operands) { @@ -440,19 +713,88 @@ static IntegerAttr getI32Attr(OpAsmParser& parser, int32_t value) { return parser.getBuilder().getI32IntegerAttr(value); } -static void buildImplicitRegionArgs(OpAsmParser& parser, - ArrayRef inputTypes, - SmallVectorImpl& generatedNames, - SmallVectorImpl& arguments) { - generatedNames.reserve(inputTypes.size()); - arguments.reserve(inputTypes.size()); - for (auto [index, inputType] : llvm::enumerate(inputTypes)) { - generatedNames.push_back("arg" + std::to_string(index + 1)); - OpAsmParser::Argument arg; - arg.ssaName = {parser.getCurrentLocation(), generatedNames.back(), 0}; - arg.type = inputType; - arguments.push_back(arg); +static void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) { + if (block.getNumArguments() == 0) { + printer << "() = ()"; + return; } + + if (block.getNumArguments() == 1) { + printer.printOperand(block.getArgument(0)); + printer << " = "; + printCompressedValueList(printer, operands, ListDelimiter::Paren); + return; + } + + printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren); + printer << " = "; + printCompressedValueList(printer, operands, ListDelimiter::Paren); +} + +static ParseResult 
parseCompressedArgumentEntryWithFirst(OpAsmParser& parser, + OpAsmParser::Argument firstArgument, + SmallVectorImpl& arguments) { + if (succeeded(parser.parseOptionalKeyword("to"))) { + OpAsmParser::Argument lastArgument; + if (parser.parseArgument(lastArgument)) + return failure(); + if (firstArgument.ssaName.name != lastArgument.ssaName.name + || firstArgument.ssaName.number > lastArgument.ssaName.number) { + return parser.emitError(parser.getCurrentLocation(), "invalid argument range"); + } + for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) { + OpAsmParser::Argument argument; + argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number}; + arguments.push_back(argument); + } + return success(); + } + + arguments.push_back(firstArgument); + return success(); +} + +static ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser, + SmallVectorImpl& arguments) { + OpAsmParser::Argument firstArgument; + if (parser.parseArgument(firstArgument)) + return failure(); + return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments); +} + +static void applyArgumentTypes(ArrayRef inputTypes, SmallVectorImpl& arguments) { + for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes)) + argument.type = inputType; +} + +static ParseResult parseArgumentBindings(OpAsmParser& parser, + SmallVectorImpl& arguments, + SmallVectorImpl& operands) { + if (succeeded(parser.parseOptionalLParen())) { + if (succeeded(parser.parseOptionalRParen())) { + if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + return success(); + } + + OpAsmParser::Argument firstArgument; + if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments)) + return failure(); + while (succeeded(parser.parseOptionalComma())) + if (parseOneCompressedArgumentEntry(parser, arguments)) + return failure(); + if (parser.parseRParen() || parser.parseEqual() + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + return success(); + } + + OpAsmParser::Argument argument; + if (parser.parseArgument(argument) || parser.parseEqual() + || parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) + return failure(); + arguments.push_back(argument); + return success(); } } // namespace @@ -519,8 +861,8 @@ ParseResult SpatExtractRowsOp::parse(OpAsmParser& parser, OperationState& result void SpatConcatOp::print(OpAsmPrinter& printer) { printer << " axis " << getAxis(); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + printer << " "; + printCompressedValueSequence(printer, getInputs()); printer.printOptionalAttrDict((*this)->getAttrs(), {getAxisAttrName().getValue()}); printer << " : "; printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); @@ -537,11 +879,7 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) { if (parser.parseKeyword("axis") || parser.parseInteger(axis)) return failure(); - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + if (parseCompressedOperandSequence(parser, inputs)) { return failure(); } @@ -563,14 +901,54 @@ ParseResult SpatConcatOp::parse(OpAsmParser& 
parser, OperationState& result) { return success(); } +void SpatMapOp::print(OpAsmPrinter& printer) { + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); + printer.printOptionalAttrDict((*this)->getAttrs()); + printer << " : "; + printer.printType(getInputs().front().getType()); + printer << " -> "; + printer.printType(getOutputs().front().getType()); + printer << " "; + printer.printRegion(getBody(), /*printEntryBlockArgs=*/false); +} + +ParseResult SpatMapOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector regionArgs; + SmallVector inputs; + Type inputType; + Type outputType; + + if (parseArgumentBindings(parser, regionArgs, inputs)) + return failure(); + if (inputs.empty()) + return parser.emitError(parser.getCurrentLocation(), "map requires at least one input"); + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(inputType) + || parser.parseArrow() || parser.parseType(outputType)) + return failure(); + + SmallVector inputTypes(inputs.size(), inputType); + SmallVector outputTypes(inputs.size(), outputType); + if (regionArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) + return failure(); + result.addTypes(outputTypes); + + applyArgumentTypes(inputTypes, regionArgs); + Region* body = result.addRegion(); + return parser.parseRegion(*body, regionArgs); +} + void SpatCompute::print(OpAsmPrinter& printer) { printer << " "; printCompressedValueList(printer, getWeights(), ListDelimiter::Square); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); if (auto coreIdAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) - printer << " core_id " << coreIdAttr.getInt(); + printer << " coreId " << coreIdAttr.getInt(); printer.printOptionalAttrDict((*this)->getAttrs(), {getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); @@ -587,7 +965,6 @@ void SpatCompute::print(OpAsmPrinter& printer) { ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { SmallVector regionArgs; - SmallVector generatedArgNames; SmallVector weights; SmallVector inputs; SmallVector weightTypes; @@ -598,15 +975,10 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) return failure(); - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + if (parseArgumentBindings(parser, regionArgs, inputs)) return failure(); - } - bool hasCoreId = succeeded(parser.parseOptionalKeyword("core_id")); + bool hasCoreId = parseOptionalKeywordAlias(parser, "coreId", "core_id"); if (hasCoreId && parser.parseInteger(coreId)) return failure(); @@ -622,9 +994,11 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (regionArgs.size() 
!= inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); if (hasCoreId && result.attributes.get(onnx_mlir::kCoreIdAttrName)) return parser.emitError(parser.getCurrentLocation(), - "core_id cannot be specified both positionally and in attr-dict"); + "coreId cannot be specified both positionally and in attr-dict"); auto& builder = parser.getBuilder(); result.addAttribute( @@ -639,27 +1013,34 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) { result.addTypes(outputTypes); Region* body = result.addRegion(); - buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + applyArgumentTypes(inputTypes, regionArgs); return parser.parseRegion(*body, regionArgs); } void SpatComputeBatch::print(OpAsmPrinter& printer) { printer << " lanes " << getLaneCount() << " "; - printCompressedValueList(printer, getWeights(), ListDelimiter::Square); - printer << " args = "; - printCompressedValueList(printer, getInputs(), ListDelimiter::Paren); + size_t weightsPerLane = getLaneCount() > 0 ? getWeights().size() / static_cast(getLaneCount()) : 0; + if (getLaneCount() > 1 && hasRepeatedTuple(getWeights(), weightsPerLane)) + printValueTupleRun(printer, getWeights(), weightsPerLane); + else + printCompressedValueList(printer, getWeights(), ListDelimiter::Square); + printer << " "; + printArgumentBindings(printer, getBody().front(), getInputs()); - if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdAttrName)) { - printer << " core_ids "; + if (auto coreIdsAttr = (*this)->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) { + printer << " coreIds "; printCompressedIntegerList(printer, coreIdsAttr.asArrayRef()); } printer.printOptionalAttrDict( (*this)->getAttrs(), - {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdAttrName}); + {getLaneCountAttrName().getValue(), getOperandSegmentSizesAttrName().getValue(), onnx_mlir::kCoreIdsAttrName}); printer << " : "; - printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); + if (getLaneCount() > 1 && hasRepeatedTuple(TypeRange(getWeights()), weightsPerLane)) + printTypeTupleRun(printer, TypeRange(getWeights()), weightsPerLane); + else + printCompressedTypeList(printer, TypeRange(getWeights()), ListDelimiter::Square); printer << " "; printCompressedTypeList(printer, TypeRange(getInputs()), ListDelimiter::Paren); printer << " -> "; @@ -671,7 +1052,6 @@ void SpatComputeBatch::print(OpAsmPrinter& printer) { ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) { int32_t laneCount = 0; SmallVector regionArgs; - SmallVector generatedArgNames; SmallVector weights; SmallVector inputs; SmallVector weightTypes; @@ -682,24 +1062,18 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) if (parser.parseKeyword("lanes") || parser.parseInteger(laneCount)) return failure(); - if (parseCompressedOperandList(parser, ListDelimiter::Square, weights)) + if (parseCompressedOrTupleOperandList(parser, weights)) return failure(); - if (succeeded(parser.parseOptionalKeyword("args"))) { - if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) - return failure(); - } - else if (parseCompressedOperandList(parser, ListDelimiter::Paren, inputs)) { + if (parseArgumentBindings(parser, regionArgs, inputs)) return failure(); - } - bool hasCoreIds = succeeded(parser.parseOptionalKeyword("core_ids")); + bool hasCoreIds 
= parseOptionalKeywordAlias(parser, "coreIds", "core_ids"); if (hasCoreIds && parseCompressedIntegerList(parser, coreIds)) return failure(); if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() - || parseCompressedRepeatedList( - parser, ListDelimiter::Square, weightTypes, [&](Type& type) { return parser.parseType(type); }) + || parseCompressedOrTupleTypeList(parser, weightTypes) || parseCompressedRepeatedList( parser, ListDelimiter::Paren, inputTypes, [&](Type& type) { return parser.parseType(type); }) || parser.parseArrow() || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/true)) @@ -709,8 +1083,11 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) return parser.emitError(parser.getCurrentLocation(), "number of weights and weight types must match"); if (inputs.size() != inputTypes.size()) return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); - if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdAttrName)) - return parser.emitError(parser.getCurrentLocation(), "core_id cannot be specified both in core_ids and attr-dict"); + if (regionArgs.size() != inputs.size()) + return parser.emitError(parser.getCurrentLocation(), "number of argument bindings and input operands must match"); + if (hasCoreIds && result.attributes.get(onnx_mlir::kCoreIdsAttrName)) + return parser.emitError(parser.getCurrentLocation(), + "coreIds cannot be specified both positionally and in attr-dict"); auto& builder = parser.getBuilder(); result.addAttribute("laneCount", builder.getI32IntegerAttr(laneCount)); @@ -718,7 +1095,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) "operandSegmentSizes", builder.getDenseI32ArrayAttr({static_cast(weights.size()), static_cast(inputs.size())})); if (hasCoreIds) - result.addAttribute(onnx_mlir::kCoreIdAttrName, getDenseI32ArrayAttr(parser, coreIds)); + result.addAttribute(onnx_mlir::kCoreIdsAttrName, getDenseI32ArrayAttr(parser, coreIds)); if (parser.resolveOperands(weights, weightTypes, parser.getCurrentLocation(), result.operands) || parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands)) @@ -726,7 +1103,7 @@ ParseResult SpatComputeBatch::parse(OpAsmParser& parser, OperationState& result) result.addTypes(outputTypes); Region* body = result.addRegion(); - buildImplicitRegionArgs(parser, inputTypes, generatedArgNames, regionArgs); + applyArgumentTypes(inputTypes, regionArgs); return parser.parseRegion(*body, regionArgs); } @@ -867,6 +1244,55 @@ ParseResult SpatChannelSendBatchOp::parse(OpAsmParser& parser, OperationState& r return parser.resolveOperand(input, inputType, result.operands); } +void SpatChannelSendManyBatchOp::print(OpAsmPrinter& printer) { + printer << " "; + printCompressedValueSequence(printer, getInputs()); + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, TypeRange(getInputs())); +} + +ParseResult SpatChannelSendManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector inputs; + SmallVector inputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + if (parseCompressedOperandSequence(parser, inputs)) + return failure(); + + bool hasMetadata = 
succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, inputTypes, /*allowEmpty=*/false)) + return failure(); + + if (inputs.size() != inputTypes.size()) + return parser.emitError(parser.getCurrentLocation(), "number of inputs and input types must match"); + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + return parser.resolveOperands(inputs, inputTypes, parser.getCurrentLocation(), result.operands); +} + void SpatChannelReceiveBatchOp::print(OpAsmPrinter& printer) { printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); printer.printOptionalAttrDict( @@ -908,5 +1334,47 @@ ParseResult SpatChannelReceiveBatchOp::parse(OpAsmParser& parser, OperationState return success(); } +void SpatChannelReceiveManyBatchOp::print(OpAsmPrinter& printer) { + printChannelMetadata(printer, getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); + printer.printOptionalAttrDict( + (*this)->getAttrs(), + {getChannelIdsAttrName().getValue(), getSourceCoreIdsAttrName().getValue(), getTargetCoreIdsAttrName().getValue()}); + printer << " : "; + printCompressedTypeSequence(printer, getResultTypes()); +} + +ParseResult SpatChannelReceiveManyBatchOp::parse(OpAsmParser& parser, OperationState& result) { + SmallVector outputTypes; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + bool hasMetadata = succeeded(parser.parseOptionalKeyword("channels")); + if (hasMetadata) { + if (parseCompressedIntegerList(parser, channelIds) || parser.parseKeyword("from") + || parseCompressedIntegerList(parser, sourceCoreIds) || parser.parseKeyword("to") + || parseCompressedIntegerList(parser, targetCoreIds)) + return failure(); + } + + if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() + || parseCompressedTypeSequence(parser, outputTypes, /*allowEmpty=*/false)) + return failure(); + + if (hasMetadata + && (result.attributes.get("channelIds") || result.attributes.get("sourceCoreIds") + || result.attributes.get("targetCoreIds"))) + return parser.emitError(parser.getCurrentLocation(), + "channel metadata cannot be specified both positionally and in attr-dict"); + if (hasMetadata) { + result.addAttribute("channelIds", getDenseI64ArrayAttr(parser, channelIds)); + result.addAttribute("sourceCoreIds", getDenseI32ArrayAttr(parser, sourceCoreIds)); + result.addAttribute("targetCoreIds", getDenseI32ArrayAttr(parser, targetCoreIds)); + } + + result.addTypes(outputTypes); + return success(); +} + } // namespace spatial } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index b4d8657..6dc872b 100644 
--- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -83,13 +83,13 @@ inline LogicalResult mvmOpVerifySize4(SpatWeightedMVMOp* emitter, } static FailureOr> getWeightShapeForWeightedOp(Operation* weightedOp, size_t weightIndex) { - if (auto computeOp = dyn_cast(weightedOp->getParentOp())) + if (auto computeOp = weightedOp->getParentOfType()) return cast(computeOp.getWeights()[weightIndex].getType()).getShape(); - if (auto coreOp = dyn_cast(weightedOp->getParentOp())) + if (auto coreOp = weightedOp->getParentOfType()) return cast(coreOp.getWeights()[weightIndex].getType()).getShape(); - if (auto batchOp = dyn_cast(weightedOp->getParentOp())) { + if (auto batchOp = weightedOp->getParentOfType()) { if (batchOp.getWeights().empty() || weightIndex >= batchOp.getWeights().size()) return failure(); return cast(batchOp.getWeights()[weightIndex].getType()).getShape(); @@ -144,6 +144,23 @@ static LogicalResult verifyBatchChannelSizes(Operation* op, return success(); } +static LogicalResult verifyManyBatchChannelSizes(Operation* op, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + size_t valueCount) { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return op->emitError("channelIds, sourceCoreIds, and targetCoreIds must have the same length"); + + auto laneCount = getParentBatchLaneCount(op); + if (failed(laneCount)) + return op->emitError("must be nested inside spat.compute_batch"); + if (channelIds.size() != valueCount * static_cast(*laneCount)) + return op->emitError("channel metadata length must match the number of values times parent laneCount"); + + return success(); +} + static LogicalResult verifyBatchBody(Operation* op, Block& block, TypeRange outputTypes, size_t weightsPerLane) { auto yieldOp = dyn_cast_or_null(block.getTerminator()); if (!yieldOp) @@ -306,6 +323,39 @@ LogicalResult SpatConcatOp::verify() { return success(); } +LogicalResult SpatMapOp::verify() { + if (getInputs().empty()) + return emitError("requires at least one input"); + if (getOutputs().size() != getInputs().size()) + return emitError("number of outputs must match number of inputs"); + + Type inputType = getInputs().front().getType(); + for (Value input : getInputs().drop_front()) + if (input.getType() != inputType) + return emitError("all inputs must have the same type"); + + Type outputType = getOutputs().front().getType(); + for (Value output : getOutputs().drop_front()) + if (output.getType() != outputType) + return emitError("all outputs must have the same type"); + + Block& block = getBody().front(); + if (block.getNumArguments() != 1) + return emitError("body must have exactly one block argument"); + if (block.getArgument(0).getType() != inputType) + return emitError("body block argument type must match input type"); + + auto yieldOp = dyn_cast_or_null(block.getTerminator()); + if (!yieldOp) + return emitError("body must terminate with spat.yield"); + if (yieldOp.getNumOperands() != 1) + return emitError("body yield must produce exactly one value"); + if (yieldOp.getOperand(0).getType() != outputType) + return emitError("body yield type must match output type"); + + return success(); +} + LogicalResult SpatCompute::verify() { auto& block = getBody().front(); if (block.mightHaveTerminator()) { @@ -365,10 +415,24 @@ LogicalResult SpatChannelSendBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } +LogicalResult 
SpatChannelSendManyBatchOp::verify() { + if (failed(verifyManyBatchChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getInputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getInputs().getTypes(), "channel_send_many_batch"); +} + LogicalResult SpatChannelReceiveBatchOp::verify() { return verifyBatchChannelSizes(getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds()); } +LogicalResult SpatChannelReceiveManyBatchOp::verify() { + if (failed(verifyManyBatchChannelSizes( + getOperation(), getChannelIds(), getSourceCoreIds(), getTargetCoreIds(), getOutputs().size()))) + return failure(); + return verifyManyChannelTypes(getOperation(), getOperation()->getResultTypes(), "channel_receive_many_batch"); +} + LogicalResult SpatComputeBatch::verify() { int32_t count = getLaneCount(); if (count <= 0) @@ -405,18 +469,18 @@ LogicalResult SpatComputeBatch::verify() { return emitError("all outputs must have the same type"); } - if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdAttrName)) { + if (auto coreIdAttr = (*this)->getAttr(onnx_mlir::kCoreIdsAttrName)) { auto coreIdsAttr = dyn_cast(coreIdAttr); if (!coreIdsAttr) - return emitError("compute_batch core_id attribute must be a dense i32 array"); + return emitError("compute_batch coreIds attribute must be a dense i32 array"); if (coreIdsAttr.size() != laneCountSz) - return emitError("compute_batch core_id array length must match laneCount"); + return emitError("compute_batch coreIds array length must match laneCount"); if (llvm::any_of(coreIdsAttr.asArrayRef(), [](int32_t coreId) { return coreId <= 0; })) - return emitError("compute_batch core_id values must be positive"); + return emitError("compute_batch coreIds values must be positive"); llvm::SmallDenseSet seenCoreIds; for (int32_t coreId : coreIdsAttr.asArrayRef()) if (!seenCoreIds.insert(coreId).second) - return emitError("compute_batch core_id values must be distinct"); + return emitError("compute_batch coreIds values must be distinct"); } Block& block = getBody().front(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index 7cb69b0..df0422d 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -1,5 +1,7 @@ #include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Location.h" @@ -35,6 +37,7 @@ #include #include "DCPGraph/DCPAnalysis.hpp" +#include "RegularOpCompaction.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" @@ -147,7 +150,7 @@ static SmallVector getMaterializedBatchCoreIds(size_t startCpu, size_t } static SmallVector getBatchCoreIds(Operation* op, size_t laneCount) { - if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) + if (auto coreIdsAttr = op->getAttrOfType(onnx_mlir::kCoreIdsAttrName)) return SmallVector(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); if (auto coreIdAttr = op->getAttrOfType(onnx_mlir::kCoreIdAttrName)) return SmallVector(laneCount, static_cast(coreIdAttr.getInt())); @@ -304,7 +307,7 @@ static void 
sinkChannelsIntoBatchComputes(func::FuncOp funcOp, SmallVector coreIds = getBatchCoreIds(batch, static_cast(batch.getLaneCount())); if (!coreIds.empty()) - newBatch->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); auto* newBlock = rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange {}, ArrayRef {}); @@ -548,141 +551,6 @@ void sinkChannelsIntoComputes(func::FuncOp funcOp, int64_t& nextChannelId) { } } -static void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { - IRRewriter rewriter(funcOp.getContext()); - - for (auto compute : funcOp.getOps()) { - Block& block = compute.getBody().front(); - for (auto it = block.begin(); it != block.end();) { - auto receiveOp = dyn_cast(&*it); - if (receiveOp) { - SmallVector run; - Type outputType = receiveOp.getOutput().getType(); - auto runIt = it; - while (runIt != block.end()) { - auto current = dyn_cast(&*runIt); - if (!current || current.getOutput().getType() != outputType) - break; - run.push_back(current); - ++runIt; - } - - if (run.size() > 1) { - struct ReceiveEntry { - spatial::SpatChannelReceiveOp op; - size_t originalIndex = 0; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - uint64_t channelId = 0; - }; - SmallVector sortedEntries; - sortedEntries.reserve(run.size()); - for (auto [originalIndex, op] : llvm::enumerate(run)) - sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); - llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector outputTypes; - channelIds.reserve(sortedEntries.size()); - sourceCoreIds.reserve(sortedEntries.size()); - targetCoreIds.reserve(sortedEntries.size()); - outputTypes.reserve(sortedEntries.size()); - for (ReceiveEntry& entry : sortedEntries) { - (void) entry; - channelIds.push_back(nextChannelId++); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - outputTypes.push_back(entry.op.getOutput().getType()); - } - - rewriter.setInsertionPoint(run.front()); - auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter, - run.front().getLoc(), - TypeRange(outputTypes), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds)); - for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) - entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex)); - for (auto op : run) - rewriter.eraseOp(op); - - it = compactReceive->getIterator(); - ++it; - continue; - } - } - - auto sendOp = dyn_cast(&*it); - if (sendOp) { - SmallVector run; - Type inputType = sendOp.getInput().getType(); - auto runIt = it; - while (runIt != block.end()) { - auto current = dyn_cast(&*runIt); - if (!current || current.getInput().getType() != inputType) - break; - run.push_back(current); - ++runIt; - } - - if (run.size() > 1) { - struct SendEntry { - spatial::SpatChannelSendOp op; - uint32_t sourceCoreId = 0; - uint32_t targetCoreId = 0; - uint64_t channelId = 0; - }; - SmallVector sortedEntries; - sortedEntries.reserve(run.size()); - for (auto op : run) - 
sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); - llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) { - return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) - < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); - }); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector inputs; - channelIds.reserve(sortedEntries.size()); - sourceCoreIds.reserve(sortedEntries.size()); - targetCoreIds.reserve(sortedEntries.size()); - inputs.reserve(sortedEntries.size()); - for (SendEntry& entry : sortedEntries) { - (void) entry; - channelIds.push_back(nextChannelId++); - sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); - targetCoreIds.push_back(static_cast(entry.targetCoreId)); - inputs.push_back(entry.op.getInput()); - } - - rewriter.setInsertionPoint(run.front()); - spatial::SpatChannelSendManyOp::create(rewriter, - run.front().getLoc(), - rewriter.getDenseI64ArrayAttr(channelIds), - rewriter.getDenseI32ArrayAttr(sourceCoreIds), - rewriter.getDenseI32ArrayAttr(targetCoreIds), - ValueRange(inputs)); - for (auto op : run) - rewriter.eraseOp(op); - - it = runIt; - continue; - } - } - - ++it; - } - } -} - void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { IRRewriter rewriter(funcOp.getContext()); SmallVector computes(funcOp.getOps()); @@ -755,7 +623,7 @@ void rebatchEquivalentComputes(func::FuncOp funcOp, int64_t& nextChannelId) { rebatched.getProperties().setOperandSegmentSizes( {static_cast(weights.size()), static_cast(inputs.size())}); if (haveAllCoreIds) - rebatched->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); SmallVector blockArgTypes; SmallVector blockArgLocs; @@ -1879,6 +1747,9 @@ public: rebatchEquivalentComputes(func, nextChannelId); compactScalarChannelRuns(func, nextChannelId); + compactBatchChannelRuns(func); + compactRegularOpRuns(func); + compactRowWiseWvmmRuns(func); if (!sortTopologically(&func.getBody().front())) { func.emitOpError("failed to topologically order merged Spatial IR"); signalPassFailure(); @@ -2049,7 +1920,7 @@ private: rewriter.getI32IntegerAttr(static_cast(laneCount)), ValueRange(weights), ValueRange(inputs)); - rebatched->setAttr(onnx_mlir::kCoreIdAttrName, + rebatched->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(getMaterializedBatchCoreIds(currentCpu, laneCount))); SmallVector blockArgTypes; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp new file mode 100644 index 0000000..bd50aa2 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.cpp @@ -0,0 +1,577 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +#include + +#include "RegularOpCompaction.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace { + +enum class RegularStepKind { + Wvmm, + VAddLhs, + VAddRhs, +}; + +struct RegularStep { + RegularStepKind kind; + int32_t weightIndex = 0; + Value 
invariantOperand; + Type resultType; +}; + +struct RegularChunk { + Operation* startOp = nullptr; + SmallVector ops; + SmallVector steps; + Value input; + Value output; +}; + +static bool areEquivalentRegularSteps(const RegularStep& lhs, const RegularStep& rhs) { + return lhs.kind == rhs.kind && lhs.weightIndex == rhs.weightIndex && lhs.invariantOperand == rhs.invariantOperand + && lhs.resultType == rhs.resultType; +} + +static bool areEquivalentRegularChunks(const RegularChunk& lhs, const RegularChunk& rhs) { + if (lhs.input.getType() != rhs.input.getType() || lhs.output.getType() != rhs.output.getType() + || lhs.steps.size() != rhs.steps.size()) { + return false; + } + + return llvm::all_of(llvm::zip_equal(lhs.steps, rhs.steps), + [](auto pair) { return areEquivalentRegularSteps(std::get<0>(pair), std::get<1>(pair)); }); +} + +static FailureOr analyzeRegularChunk(spatial::SpatWeightedVMMOp startOp) { + RegularChunk chunk; + chunk.startOp = startOp.getOperation(); + chunk.input = startOp.getInput(); + chunk.output = startOp.getOutput(); + chunk.ops.push_back(startOp.getOperation()); + chunk.steps.push_back( + {RegularStepKind::Wvmm, static_cast(startOp.getWeightIndex()), Value(), startOp.getOutput().getType()}); + + Value currentValue = startOp.getOutput(); + while (currentValue.hasOneUse()) { + Operation* user = *currentValue.getUsers().begin(); + if (user->getBlock() != startOp->getBlock()) + break; + + auto vaddOp = dyn_cast(user); + if (!vaddOp) + break; + + if (vaddOp.getLhs() == currentValue) + chunk.steps.push_back({RegularStepKind::VAddLhs, 0, vaddOp.getRhs(), vaddOp.getOutput().getType()}); + else if (vaddOp.getRhs() == currentValue) + chunk.steps.push_back({RegularStepKind::VAddRhs, 0, vaddOp.getLhs(), vaddOp.getOutput().getType()}); + else + break; + + chunk.ops.push_back(vaddOp); + chunk.output = vaddOp.getOutput(); + currentValue = vaddOp.getOutput(); + } + + return chunk; +} + +static void buildRegularMapBody(spatial::SpatMapOp mapOp, const RegularChunk& anchorChunk, IRRewriter& rewriter) { + auto* block = rewriter.createBlock( + &mapOp.getBody(), mapOp.getBody().end(), TypeRange {anchorChunk.input.getType()}, {anchorChunk.startOp->getLoc()}); + rewriter.setInsertionPointToEnd(block); + + IRMapping mapping; + mapping.map(anchorChunk.input, block->getArgument(0)); + + for (Operation* op : anchorChunk.ops) { + Operation* cloned = rewriter.clone(*op, mapping); + for (auto [oldResult, newResult] : llvm::zip(op->getResults(), cloned->getResults())) + mapping.map(oldResult, newResult); + } + + spatial::SpatYieldOp::create( + rewriter, anchorChunk.startOp->getLoc(), ValueRange {mapping.lookup(anchorChunk.output)}); +} + +static void compactRegularChunkRun(IRRewriter& rewriter, ArrayRef run) { + assert(!run.empty() && "expected a non-empty regular chunk run"); + const RegularChunk& anchorChunk = run.front(); + + SmallVector inputs; + SmallVector outputTypes; + inputs.reserve(run.size()); + outputTypes.reserve(run.size()); + for (const RegularChunk& chunk : run) { + inputs.push_back(chunk.input); + outputTypes.push_back(chunk.output.getType()); + } + + rewriter.setInsertionPoint(anchorChunk.startOp); + auto mapOp = + spatial::SpatMapOp::create(rewriter, anchorChunk.startOp->getLoc(), TypeRange(outputTypes), ValueRange(inputs)); + buildRegularMapBody(mapOp, anchorChunk, rewriter); + + for (auto [index, chunk] : llvm::enumerate(run)) { + Value output = chunk.output; + output.replaceAllUsesWith(mapOp.getResult(index)); + } + + SmallVector opsToErase; + for (const RegularChunk& chunk : 
run) + llvm::append_range(opsToErase, chunk.ops); + for (Operation* op : llvm::reverse(opsToErase)) + rewriter.eraseOp(op); +} + +} // namespace + +void compactScalarChannelRuns(func::FuncOp funcOp, int64_t& nextChannelId) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto compute : funcOp.getOps()) { + Block& block = compute.getBody().front(); + for (auto it = block.begin(); it != block.end();) { + auto receiveOp = dyn_cast(&*it); + if (receiveOp) { + SmallVector run; + Type outputType = receiveOp.getOutput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getOutput().getType() != outputType) + break; + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + struct ReceiveEntry { + spatial::SpatChannelReceiveOp op; + size_t originalIndex = 0; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + uint64_t channelId = 0; + }; + SmallVector sortedEntries; + sortedEntries.reserve(run.size()); + for (auto [originalIndex, op] : llvm::enumerate(run)) + sortedEntries.push_back({op, originalIndex, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + llvm::stable_sort(sortedEntries, [](const ReceiveEntry& lhs, const ReceiveEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector outputTypes; + channelIds.reserve(sortedEntries.size()); + sourceCoreIds.reserve(sortedEntries.size()); + targetCoreIds.reserve(sortedEntries.size()); + outputTypes.reserve(sortedEntries.size()); + for (ReceiveEntry& entry : sortedEntries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + outputTypes.push_back(entry.op.getOutput().getType()); + } + + rewriter.setInsertionPoint(run.front()); + auto compactReceive = spatial::SpatChannelReceiveManyOp::create(rewriter, + run.front().getLoc(), + TypeRange(outputTypes), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); + for (auto [sortedIndex, entry] : llvm::enumerate(sortedEntries)) + entry.op.getOutput().replaceAllUsesWith(compactReceive.getResult(sortedIndex)); + for (auto op : run) + rewriter.eraseOp(op); + + it = compactReceive->getIterator(); + ++it; + continue; + } + } + + auto sendOp = dyn_cast(&*it); + if (sendOp) { + SmallVector run; + Type inputType = sendOp.getInput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getInput().getType() != inputType) + break; + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + struct SendEntry { + spatial::SpatChannelSendOp op; + uint32_t sourceCoreId = 0; + uint32_t targetCoreId = 0; + uint64_t channelId = 0; + }; + SmallVector sortedEntries; + sortedEntries.reserve(run.size()); + for (auto op : run) + sortedEntries.push_back({op, op.getSourceCoreId(), op.getTargetCoreId(), op.getChannelId()}); + llvm::stable_sort(sortedEntries, [](const SendEntry& lhs, const SendEntry& rhs) { + return std::tuple(lhs.sourceCoreId, lhs.targetCoreId, lhs.channelId) + < std::tuple(rhs.sourceCoreId, rhs.targetCoreId, rhs.channelId); + }); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + 
SmallVector inputs; + channelIds.reserve(sortedEntries.size()); + sourceCoreIds.reserve(sortedEntries.size()); + targetCoreIds.reserve(sortedEntries.size()); + inputs.reserve(sortedEntries.size()); + for (SendEntry& entry : sortedEntries) { + (void) entry; + channelIds.push_back(nextChannelId++); + sourceCoreIds.push_back(static_cast(entry.sourceCoreId)); + targetCoreIds.push_back(static_cast(entry.targetCoreId)); + inputs.push_back(entry.op.getInput()); + } + + rewriter.setInsertionPoint(run.front()); + spatial::SpatChannelSendManyOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + ValueRange(inputs)); + for (auto op : run) + rewriter.eraseOp(op); + + it = runIt; + continue; + } + } + + ++it; + } + } +} + +void compactBatchChannelRuns(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto batch : funcOp.getOps()) { + Block& block = batch.getBody().front(); + for (auto it = block.begin(); it != block.end();) { + auto receiveOp = dyn_cast(&*it); + if (receiveOp) { + SmallVector run; + Type outputType = receiveOp.getOutput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getOutput().getType() != outputType) + break; + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector outputTypes; + outputTypes.reserve(run.size()); + for (auto op : run) { + llvm::append_range(channelIds, op.getChannelIds()); + llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); + llvm::append_range(targetCoreIds, op.getTargetCoreIds()); + outputTypes.push_back(op.getOutput().getType()); + } + + rewriter.setInsertionPoint(run.front()); + auto compactReceive = + spatial::SpatChannelReceiveManyBatchOp::create(rewriter, + run.front().getLoc(), + TypeRange(outputTypes), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds)); + for (auto [index, op] : llvm::enumerate(run)) + op.getOutput().replaceAllUsesWith(compactReceive.getResult(index)); + for (auto op : run) + rewriter.eraseOp(op); + + it = compactReceive->getIterator(); + ++it; + continue; + } + } + + auto sendOp = dyn_cast(&*it); + if (sendOp) { + SmallVector run; + Type inputType = sendOp.getInput().getType(); + auto runIt = it; + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getInput().getType() != inputType) + break; + run.push_back(current); + ++runIt; + } + + if (run.size() > 1) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + SmallVector inputs; + inputs.reserve(run.size()); + for (auto op : run) { + llvm::append_range(channelIds, op.getChannelIds()); + llvm::append_range(sourceCoreIds, op.getSourceCoreIds()); + llvm::append_range(targetCoreIds, op.getTargetCoreIds()); + inputs.push_back(op.getInput()); + } + + rewriter.setInsertionPoint(run.front()); + spatial::SpatChannelSendManyBatchOp::create(rewriter, + run.front().getLoc(), + rewriter.getDenseI64ArrayAttr(channelIds), + rewriter.getDenseI32ArrayAttr(sourceCoreIds), + rewriter.getDenseI32ArrayAttr(targetCoreIds), + ValueRange(inputs)); + for (auto op : run) + rewriter.eraseOp(op); + + it = runIt; + continue; + } + } + + ++it; + } + } +} + +void compactRegularOpRuns(func::FuncOp funcOp) { + 
IRRewriter rewriter(funcOp.getContext()); + + auto compactInBlock = [&](Block& block) { + for (auto it = block.begin(); it != block.end();) { + auto startOp = dyn_cast(&*it); + if (!startOp) { + ++it; + continue; + } + + auto anchorChunk = analyzeRegularChunk(startOp); + if (failed(anchorChunk)) { + ++it; + continue; + } + + SmallVector run {*anchorChunk}; + auto runIt = std::next(it, static_cast(anchorChunk->ops.size())); + while (runIt != block.end()) { + auto candidateStart = dyn_cast(&*runIt); + if (!candidateStart) + break; + + auto candidateChunk = analyzeRegularChunk(candidateStart); + if (failed(candidateChunk) || !areEquivalentRegularChunks(*anchorChunk, *candidateChunk)) + break; + + run.push_back(*candidateChunk); + runIt = std::next(runIt, static_cast(candidateChunk->ops.size())); + } + + if (run.size() <= 1) { + ++it; + continue; + } + + compactRegularChunkRun(rewriter, run); + it = runIt; + } + }; + + for (auto compute : funcOp.getOps()) + compactInBlock(compute.getBody().front()); + for (auto batch : funcOp.getOps()) + compactInBlock(batch.getBody().front()); +} + +void compactRowWiseWvmmRuns(func::FuncOp funcOp) { + IRRewriter rewriter(funcOp.getContext()); + + for (auto compute : funcOp.getOps()) { + Block& block = compute.getBody().front(); + for (auto it = block.begin(); it != block.end();) { + auto wvmmOp = dyn_cast(&*it); + if (!wvmmOp) { + ++it; + continue; + } + + auto extractRowsOp = wvmmOp.getInput().getDefiningOp(); + auto rowResult = dyn_cast(wvmmOp.getInput()); + auto outputType = dyn_cast(wvmmOp.getOutput().getType()); + if (!extractRowsOp || !rowResult || rowResult.getOwner() != extractRowsOp || !outputType + || !outputType.hasStaticShape() || outputType.getRank() != 2 || outputType.getShape()[0] != 1) { + ++it; + continue; + } + + SmallVector run; + auto runIt = it; + int64_t expectedRow = static_cast(rowResult.getResultNumber()); + while (runIt != block.end()) { + auto current = dyn_cast(&*runIt); + if (!current || current.getWeightIndex() != wvmmOp.getWeightIndex() + || current.getInput().getDefiningOp() != extractRowsOp + || current.getInput().getType() != wvmmOp.getInput().getType() + || current.getOutput().getType() != wvmmOp.getOutput().getType()) { + break; + } + + auto currentRow = dyn_cast(current.getInput()); + if (!currentRow || currentRow.getResultNumber() != static_cast(expectedRow)) + break; + + run.push_back(current); + ++expectedRow; + ++runIt; + } + + if (run.size() <= 1) { + ++it; + continue; + } + + if (!run.front().getOutput().hasOneUse()) { + ++it; + continue; + } + auto concatUse = run.front().getOutput().getUses().begin(); + auto concatOp = dyn_cast(concatUse->getOwner()); + if (!concatOp) { + ++it; + continue; + } + + unsigned concatStartIndex = concatUse->getOperandNumber(); + bool validConcatRun = true; + for (auto [index, op] : llvm::enumerate(run)) { + if (!op.getOutput().hasOneUse()) { + validConcatRun = false; + break; + } + OpOperand& use = *op.getOutput().getUses().begin(); + if (use.getOwner() != concatOp || use.getOperandNumber() != concatStartIndex + index) { + validConcatRun = false; + break; + } + } + if (!validConcatRun) { + ++it; + continue; + } + + auto inputType = dyn_cast(wvmmOp.getInput().getType()); + auto sourceType = dyn_cast(extractRowsOp.getInput().getType()); + if (!inputType || !sourceType || !inputType.hasStaticShape() || !sourceType.hasStaticShape()) { + ++it; + continue; + } + + int64_t inputCols = inputType.getShape()[1]; + int64_t outputCols = outputType.getShape()[1]; + if 
(ShapedType::isDynamic(inputCols) || ShapedType::isDynamic(outputCols)) { + ++it; + continue; + } + + int64_t firstRow = static_cast(rowResult.getResultNumber()); + int64_t runLength = static_cast(run.size()); + auto packedType = RankedTensorType::get({runLength, outputCols}, outputType.getElementType()); + + rewriter.setInsertionPoint(run.front()); + auto zero = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 0); + auto upper = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), runLength); + auto step = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), 1); + auto packedInit = + tensor::EmptyOp::create(rewriter, run.front().getLoc(), packedType.getShape(), packedType.getElementType()); + auto loop = + scf::ForOp::create(rewriter, run.front().getLoc(), zero, upper, step, ValueRange {packedInit.getResult()}); + + { + OpBuilder::InsertionGuard guard(rewriter); + Block* loopBlock = loop.getBody(); + rewriter.setInsertionPointToStart(loopBlock); + Value iv = loopBlock->getArgument(0); + Value acc = loopBlock->getArgument(1); + + Value sourceRow = iv; + if (firstRow != 0) { + auto firstRowValue = arith::ConstantIndexOp::create(rewriter, run.front().getLoc(), firstRow); + sourceRow = arith::AddIOp::create(rewriter, run.front().getLoc(), iv, firstRowValue); + } + + SmallVector extractOffsets = {sourceRow, rewriter.getIndexAttr(0)}; + SmallVector extractSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(inputCols)}; + SmallVector extractStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto extractedRow = tensor::ExtractSliceOp::create(rewriter, + run.front().getLoc(), + inputType, + extractRowsOp.getInput(), + extractOffsets, + extractSizes, + extractStrides); + auto loopWvmm = spatial::SpatWeightedVMMOp::create( + rewriter, run.front().getLoc(), outputType, wvmmOp.getWeightIndex(), extractedRow.getResult()); + + SmallVector insertOffsets = {iv, rewriter.getIndexAttr(0)}; + SmallVector insertSizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outputCols)}; + SmallVector insertStrides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; + auto inserted = tensor::InsertSliceOp::create( + rewriter, run.front().getLoc(), loopWvmm.getResult(), acc, insertOffsets, insertSizes, insertStrides); + scf::YieldOp::create(rewriter, run.front().getLoc(), inserted.getResult()); + } + + SmallVector newConcatInputs; + newConcatInputs.reserve(concatOp.getInputs().size() - run.size() + 1); + for (auto [operandIndex, operand] : llvm::enumerate(concatOp.getInputs())) { + if (operandIndex == concatStartIndex) + newConcatInputs.push_back(loop.getResult(0)); + if (operandIndex < concatStartIndex || operandIndex >= concatStartIndex + run.size()) + newConcatInputs.push_back(operand); + } + rewriter.modifyOpInPlace(concatOp, [&] { concatOp->setOperands(newConcatInputs); }); + for (auto op : run) + rewriter.eraseOp(op); + + it = loop->getIterator(); + ++it; + } + } +} + +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp new file mode 100644 index 0000000..08b7d1e --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/RegularOpCompaction.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "mlir/Dialect/Func/IR/FuncOps.h" + +#include + +namespace onnx_mlir { + +void compactScalarChannelRuns(mlir::func::FuncOp funcOp, int64_t& nextChannelId); +void compactBatchChannelRuns(mlir::func::FuncOp funcOp); +void 
compactRegularOpRuns(mlir::func::FuncOp funcOp);
+// Folds a run of single-row weighted-VMM ops over consecutive rows of one source,
+// all feeding the same concat, into an scf.for loop that builds the packed result.
+void compactRowWiseWvmmRuns(mlir::func::FuncOp funcOp);
+
+} // namespace onnx_mlir
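
A minimal usage sketch of how these entry points are sequenced, mirroring the calls added to MergeComputeNodesPass in this patch: they run after rebatchEquivalentComputes and before the final topological sort. The wrapper function runSpatialCompaction below is illustrative only and is not part of the change.

#include <cstdint>

#include "mlir/Dialect/Func/IR/FuncOps.h"

#include "RegularOpCompaction.hpp"

// Illustrative sketch; the call order mirrors the MergeComputeNodesPass changes above.
static void runSpatialCompaction(mlir::func::FuncOp func, int64_t& nextChannelId) {
  // Fuse runs of scalar channel sends/receives inside compute bodies into
  // *Many ops, allocating fresh channel ids from nextChannelId.
  onnx_mlir::compactScalarChannelRuns(func, nextChannelId);
  // Concatenate runs of batched channel receives/sends into single many-batch ops.
  onnx_mlir::compactBatchChannelRuns(func);
  // Fold runs of equivalent WVMM(+vector-add) chains into one map op body.
  onnx_mlir::compactRegularOpRuns(func);
  // Roll consecutive row-wise WVMMs feeding one concat into an scf.for loop.
  onnx_mlir::compactRowWiseWvmmRuns(func);
}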