diff --git a/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs b/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs index ddb40ee..d33bf34 100644 --- a/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs +++ b/backend-simulators/pim/pim-simulator/src/lib/pimcore.rs @@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option>() .join(" -> "); @@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option>(); - let cycle_msg = format!("{} -> {}", cycle_str, waiting_for); + let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1); let states_msg = cycle .iter() .filter_map(|core| { states.get(core).map(|state| match state { CoreState::SendingTo(target, size) => { - format!("core {} send {}B -> {}", core, size, target) + format!("core {} send {}B -> {}", core - 1, size, target - 1) } CoreState::ReceivingFrom(source, size) => { - format!("core {} recv {}B <- {}", core, size, source) + format!("core {} recv {}B <- {}", core - 1, size, source - 1) } - CoreState::Working => format!("core {} working", core), - CoreState::Halted => format!("core {} halted", core), + CoreState::Working => format!("core {} working", core - 1), + CoreState::Halted => format!("core {} halted", core - 1), }) }) .collect::>() diff --git a/src/PIM/Compiler/PimBatchEmission.cpp b/src/PIM/Compiler/PimBatchEmission.cpp index 6bbf0d1..20b5262 100644 --- a/src/PIM/Compiler/PimBatchEmission.cpp +++ b/src/PIM/Compiler/PimBatchEmission.cpp @@ -28,23 +28,47 @@ static SmallVector getLaneChunkCoreIds(ArrayRef coreIds, size_ return laneCoreIds; } +static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) { + if (Value mapped = mapper.lookupOrNull(value)) + return mapped; + + if (auto blockArgument = dyn_cast(value)) { + assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning"); + assert(false && "unexpected captured block argument while scalarizing pim.core_batch"); + } + + Operation* definingOp = value.getDefiningOp(); + assert(definingOp && "expected captured value to be defined by an operation"); + assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning"); + + for (Value operand : definingOp->getOperands()) + (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper); + + Operation* cloned = builder.clone(*definingOp, mapper); + for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults())) + mapper.map(originalResult, clonedResult); + return mapper.lookup(value); +} + static void cloneScalarizedLaneBody(OpBuilder& builder, pim::PimCoreBatchOp coreBatchOp, unsigned lane, OperationFolder& constantFolder) { Block& oldBlock = coreBatchOp.getBody().front(); + Operation* anchorOp = builder.getInsertionBlock()->getParentOp(); size_t laneCount = static_cast(coreBatchOp.getLaneCount()); size_t weightCount = coreBatchOp.getWeights().size(); IRMapping mapper; for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) { if (blockArg.getType().isIndex()) { - mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast(lane), constantFolder)); + mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast(lane), constantFolder)); continue; } if (argIndex <= weightCount) { - mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]); + auto scalarCoreOp = cast(anchorOp); + mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1)); continue; } @@ -57,8 +81,10 @@ static void cloneScalarizedLaneBody(OpBuilder& builder, if (isa(op)) continue; + for (Value operand : op.getOperands()) + (void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper); + if (auto sendBatchOp = dyn_cast(op)) { - Operation* anchorOp = builder.getInsertionBlock()->getParentOp(); pim::PimSendOp::create( builder, sendBatchOp.getLoc(), @@ -78,7 +104,6 @@ static void cloneScalarizedLaneBody(OpBuilder& builder, } if (auto receiveBatchOp = dyn_cast(op)) { - Operation* anchorOp = builder.getInsertionBlock()->getParentOp(); auto scalarReceive = pim::PimReceiveOp::create( builder, receiveBatchOp.getLoc(), @@ -106,8 +131,8 @@ static void cloneScalarizedLaneBody(OpBuilder& builder, builder, memcpBatchOp.getLoc(), memcpBatchOp.getOutput().getType(), - getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder), - getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder), + getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder), + getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder), mapper.lookup(memcpBatchOp.getDeviceTarget()), mapper.lookup(memcpBatchOp.getHostSource()), memcpBatchOp.getSizeAttr()); @@ -141,7 +166,16 @@ LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp, auto scalarCore = pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId)); - Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end()); + SmallVector weightTypes; + SmallVector weightLocs; + weightTypes.reserve(weights.size()); + weightLocs.reserve(weights.size()); + for (Value weight : weights) { + weightTypes.push_back(weight.getType()); + weightLocs.push_back(weight.getLoc()); + } + Block* block = + builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs); builder.setInsertionPointToEnd(block); for (unsigned lane : lanes) cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index bb3f1da..397874b 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -1,6 +1,8 @@ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" @@ -97,20 +99,73 @@ static LogicalResult lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveT return success(); } +static FailureOr getDirectReturnOperandIndex(OpResult result) { + if (!result.hasOneUse()) + return failure(); + + auto returnOp = dyn_cast(*result.getUsers().begin()); + if (!returnOp) + return failure(); + return result.getUses().begin()->getOperandNumber(); +} + +static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) { + if (scale == 1) + return base; + + auto scaleValue = arith::ConstantIndexOp::create(rewriter, loc, scale).getResult(); + return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult(); +} + +static Value createHostTargetOffset(IRRewriter& rewriter, + tensor::ParallelInsertSliceOp insertSlice, + ShapedType destinationType, + IRMapping& mapper) { + int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8; + SmallVector strides(destinationType.getRank(), 1); + ArrayRef shape = destinationType.getShape(); + for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim) + strides[dim] = strides[dim + 1] * shape[dim + 1]; + + Value totalOffset; + Location loc = insertSlice.getLoc(); + for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) { + int64_t scale = strides[dim] * elementBytes; + Value scaledOffset; + if (auto attr = dyn_cast(offset)) { + auto intAttr = dyn_cast(attr); + assert(intAttr && "expected integer offset attribute"); + scaledOffset = arith::ConstantIndexOp::create(rewriter, loc, intAttr.getInt() * scale).getResult(); + } + else { + scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast(offset)), scale); + } + + totalOffset = totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() + : scaledOffset; + } + + if (!totalOffset) + totalOffset = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); + return totalOffset; +} + } // namespace LogicalResult lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { Location loc = computeBatchOp.getLoc(); Block& oldBlock = computeBatchOp.getBody().front(); - if (computeBatchOp.getNumResults() != 0) - return computeBatchOp.emitOpError( - "batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results; " - "materialize explicit communication before lowering to PIM"); - auto oldYield = dyn_cast(oldBlock.getTerminator()); - if (!oldYield || oldYield.getNumOperands() != 0) - return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield"); + auto inParallelOp = dyn_cast(oldBlock.getTerminator()); + if (computeBatchOp.getNumResults() == 0) { + if (!oldYield || oldYield.getNumOperands() != 0) + return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield"); + } + else if (!inParallelOp) { + return computeBatchOp.emitOpError( + "resultful compute_batch lowering currently requires a spat.in_parallel terminator"); + } SmallVector coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); SmallVector batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); @@ -128,9 +183,24 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& {static_cast(batchWeights.size()), static_cast(batchInputs.size())}); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); + SmallVector hostOutputTensors; + if (computeBatchOp.getNumResults() != 0) { + hostOutputTensors.resize(computeBatchOp.getNumResults()); + for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) { + FailureOr returnOperandIndex = getDirectReturnOperandIndex(cast(result)); + if (failed(returnOperandIndex)) + return computeBatchOp.emitOpError( + "resultful compute_batch lowering currently requires each result to be used directly by func.return"); + + hostOutputTensors[resultIndex] = state.outputTensors[*returnOperandIndex](rewriter, loc); + result.replaceAllUsesWith(hostOutputTensors[resultIndex]); + } + } + SmallVector blockArgTypes; SmallVector blockArgLocs; - for (BlockArgument arg : oldBlock.getArguments()) { + unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size(); + for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) { blockArgTypes.push_back(arg.getType()); blockArgLocs.push_back(arg.getLoc()); } @@ -183,6 +253,38 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& if (isa(op)) continue; + if (auto parallelOp = dyn_cast(op)) { + unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber(); + for (Operation& nestedOp : parallelOp.getRegion().front()) { + auto insertSlice = dyn_cast(&nestedOp); + if (!insertSlice) + return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel"); + + auto outputArg = dyn_cast(insertSlice.getDest()); + if (!outputArg || outputArg.getOwner() != &oldBlock) + return insertSlice.emitOpError("expected compute_batch output block argument destination"); + + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg; + if (resultIndex >= hostOutputTensors.size()) + return insertSlice.emitOpError("result index out of range while lowering host batch output"); + + Value mappedSource = mapper.lookup(insertSlice.getSource()); + auto hostTarget = hostOutputTensors[resultIndex]; + auto hostTargetType = cast(hostTarget.getType()); + Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); + Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult(); + pim::PimMemCopyDevToHostOp::create(rewriter, + insertSlice.getLoc(), + hostTarget.getType(), + hostTargetOffset, + zeroOffset, + hostTarget, + mappedSource, + getTensorSizeInBytesAttr(rewriter, mappedSource)); + } + continue; + } + if (auto sendBatchOp = dyn_cast(op)) { FailureOr> targetCoreIds = getConstantI32Values(sendBatchOp.getTargetCoreIds()); if (failed(targetCoreIds)) diff --git a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt index 1d004ca..990bad7 100644 --- a/src/PIM/Conversion/SpatialToPim/CMakeLists.txt +++ b/src/PIM/Conversion/SpatialToPim/CMakeLists.txt @@ -6,7 +6,6 @@ add_pim_library(OMSpatialToPim SpatialToPimPass.cpp BatchCoreLoweringPatterns.cpp ChannelLoweringPatterns.cpp - Cleanup.cpp Common.cpp ComputeLikeRegionUtils.cpp CoreLoweringPatterns.cpp diff --git a/src/PIM/Conversion/SpatialToPim/Cleanup.cpp b/src/PIM/Conversion/SpatialToPim/Cleanup.cpp deleted file mode 100644 index 2da4aa8..0000000 --- a/src/PIM/Conversion/SpatialToPim/Cleanup.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "llvm/ADT/STLExtras.h" - -#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp" - -using namespace mlir; - -namespace onnx_mlir { - -LogicalResult erasePendingOps(SmallVectorImpl& pendingOps, IRRewriter& rewriter) { - while (!pendingOps.empty()) { - bool erasedAnyOp = false; - for (auto it = pendingOps.begin(); it != pendingOps.end();) { - Operation* opToRemove = *it; - if (!opToRemove->use_empty()) { - ++it; - continue; - } - - rewriter.eraseOp(opToRemove); - it = pendingOps.erase(it); - erasedAnyOp = true; - } - - if (erasedAnyOp) - continue; - - for (Operation* opToRemove : pendingOps) { - InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation"); - diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)"; - for (Operation* user : opToRemove->getUsers()) { - bool userPendingRemoval = llvm::is_contained(pendingOps, user); - opToRemove->emitRemark() << "remaining user `" << user->getName() << "`" - << (userPendingRemoval ? " is also pending removal" : " is not pending removal"); - } - } - return failure(); - } - - return success(); -} - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/Cleanup.hpp b/src/PIM/Conversion/SpatialToPim/Cleanup.hpp deleted file mode 100644 index 8935fe7..0000000 --- a/src/PIM/Conversion/SpatialToPim/Cleanup.hpp +++ /dev/null @@ -1,11 +0,0 @@ -#pragma once - -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LLVM.h" - -namespace onnx_mlir { - -mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl& pendingOps, mlir::IRRewriter& rewriter); - -} // namespace onnx_mlir diff --git a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp index 7caf4ed..b0e5af6 100644 --- a/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp +++ b/src/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.cpp @@ -141,152 +141,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override { - Location loc = constantOp.getLoc(); - - if (hasWeightAlways(constantOp)) - return failure(); - - if (!isa(constantOp->getParentOp())) - return failure(); - - if (llvm::all_of(constantOp->getUsers(), [](Operation* op) { - if (isa(op)) - return false; - if (isa(op->getParentOp())) - return true; - return false; - })) - return failure(); - - rewriter.setInsertionPoint(constantOp->getParentOfType()); - - auto constRankedTensorType = llvm::dyn_cast(constantOp.getType()); - - if (constRankedTensorType) { - mlir::MemRefType memRefType = - mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType()); - auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter, - loc, - constantOp->getParentOfType(), - "const", - memRefType, - constantOp.getValueAttr(), - rewriter.getUnitAttr()); - std::string argName = globalOp.getSymName().str(); - - llvm::DenseMap mapSpatComputeToConst; - - for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { - auto constUsers = constUses.getOwner(); - - if (auto spatCompute = llvm::dyn_cast(constUsers)) { - auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber()); - if (!inputIndex) - return failure(); - auto BBArgIndex = *inputIndex; - rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); - if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { - auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); - auto toTensor = bufferization::ToTensorOp::create( - rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()}); - } - - replaceAndEraseDirectComputeLikeInput( - rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]); - } - else if (auto spatComputeBatch = llvm::dyn_cast(constUsers)) { - auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber()); - if (!inputIndex) - return failure(); - auto BBArgIndex = *inputIndex; - rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); - if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { - auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); - auto toTensor = bufferization::ToTensorOp::create( - rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()}); - } - - replaceAndEraseDirectComputeLikeInput(rewriter, - spatComputeBatch.getOperation(), - BBArgIndex, - mapSpatComputeToConst[spatComputeBatch.getOperation()]); - } - else { - { - - if (auto spatCompute = constUses.getOwner()->getParentOfType()) { - rewriter.setInsertionPoint(&spatCompute.getBody().front().front()); - if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) { - auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); - auto toTensor = bufferization::ToTensorOp::create( - rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()}); - } - - rewriter.startOpModification(spatCompute.getOperation()); - constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]); - rewriter.finalizeOpModification(spatCompute.getOperation()); - } - else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType()) { - rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front()); - if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) { - auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName); - auto toTensor = bufferization::ToTensorOp::create( - rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr()); - mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()}); - } - - rewriter.startOpModification(spatComputeBatch.getOperation()); - constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]); - rewriter.finalizeOpModification(spatComputeBatch.getOperation()); - } - } - } - } - } - else if (constantOp.getType().isIntOrIndexOrFloat()) { - Value hostConstant = constantOp.getResult(); - - for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) { - auto constUsers = constUses.getOwner(); - - if (auto spatCompute = llvm::dyn_cast(constUsers)) { - auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber()); - if (!inputIndex) - return failure(); - auto BBArgIndex = *inputIndex; - replaceAndEraseDirectComputeLikeInput(rewriter, spatCompute.getOperation(), BBArgIndex, hostConstant); - } - else if (auto spatComputeBatch = llvm::dyn_cast(constUsers)) { - auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber()); - if (!inputIndex) - return failure(); - auto BBArgIndex = *inputIndex; - replaceAndEraseDirectComputeLikeInput(rewriter, spatComputeBatch.getOperation(), BBArgIndex, hostConstant); - } - else if (constUsers->getParentOfType()) { - constUses.set(hostConstant); - } - else { - auto batchParent = constUsers->getParentOfType(); - assert(batchParent && "Global Constant used direcly not within a compute"); - constUses.set(hostConstant); - } - } - } - if (constantOp->use_empty()) - rewriter.eraseOp(constantOp); - return success(); - } -}; - // Materializes public function tensor inputs as globals so compute bodies can load them uniformly. struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -363,7 +217,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern( + patterns.add( patterns.getContext()); } diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index 7957262..a8807a7 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -14,7 +14,6 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "llvm/ADT/StringRef.h" @@ -28,7 +27,6 @@ #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp" -#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp" @@ -67,6 +65,7 @@ private: LogicalResult allocateAndInitializeCoreLocalVariables(func::FuncOp funcOp, IRRewriter& rewriter); void markOpToRemove(Operation* op); + void eraseOpsToRemove(); void enlargeVMMOutTensorsToCrossbarSize(func::FuncOp funcOp, IRRewriter& rewriter); }; @@ -268,13 +267,7 @@ void SpatialToPimPass::runOnOperation() { enlargeVMMOutTensorsToCrossbarSize(funcOp, rewriter); replaceReturnWithOutputBuffers(returnOp, rewriter, returnPathState); - - SmallVector pendingRemovals(operationsToRemove.begin(), operationsToRemove.end()); - if (failed(erasePendingOps(pendingRemovals, rewriter))) { - funcOp.emitOpError("failed to erase obsolete Spatial ops after lowering to PIM"); - signalPassFailure(); - return; - } + eraseOpsToRemove(); RewritePatternSet finalTensorPackingPatterns(ctx); populateTensorPackingPatterns(finalTensorPackingPatterns); @@ -399,6 +392,13 @@ void SpatialToPimPass::markOpToRemove(Operation* op) { operationsToRemove.push_back(op); } +void SpatialToPimPass::eraseOpsToRemove() { + for (Operation* op : operationsToRemove) { + op->dropAllUses(); + op->erase(); + } +} + std::unique_ptr createSpatialToPimPass() { return std::make_unique(); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Pim/PimOpsVerify.cpp b/src/PIM/Dialect/Pim/PimOpsVerify.cpp index ad7139a..bd3c80c 100644 --- a/src/PIM/Dialect/Pim/PimOpsVerify.cpp +++ b/src/PIM/Dialect/Pim/PimOpsVerify.cpp @@ -1,4 +1,6 @@ +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Block.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/OpDefinition.h" @@ -6,6 +8,7 @@ #include "llvm/Support/LogicalResult.h" +#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -40,7 +43,18 @@ static bool isDefinedInsideRegion(Value value, Region& region) { static bool isConstantExternalValue(Value value) { Operation* definingOp = value.getDefiningOp(); - return definingOp && definingOp->hasTrait(); + if (!definingOp) + return false; + if (definingOp->hasTrait()) + return true; + + auto getGlobalOp = dyn_cast(definingOp); + if (!getGlobalOp) + return false; + + auto moduleOp = definingOp->getParentOfType(); + auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp); + return globalOp && globalOp.getConstant(); } static LogicalResult verifyOnlyConstantExternalValues(Operation* ownerOp, Region& region, StringRef kind) { diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index c597f47..8efd47d 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -120,6 +120,15 @@ static bool isSupportedLaneOffsetExpr(Value value, BlockArgument laneArg) { if (value == laneArg || isConstantIndexLike(value)) return true; + auto extractOp = value.getDefiningOp(); + if (extractOp) { + auto constantTensor = extractOp.getTensor().getDefiningOp(); + auto denseAttr = constantTensor ? dyn_cast(constantTensor.getValue()) : nullptr; + if (!denseAttr || denseAttr.getType().getRank() != 1 || extractOp.getIndices().size() != 1) + return false; + return isSupportedLaneOffsetExpr(extractOp.getIndices().front(), laneArg); + } + auto addOp = value.getDefiningOp(); if (!addOp) return false; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 73fdbd2..14a2184 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -1,23 +1,31 @@ +#include "MaterializeMergeSchedule.hpp" + #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include #include #include +#include #include #include +#include -#include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" @@ -25,1257 +33,1314 @@ using namespace mlir; namespace onnx_mlir { namespace spatial { - namespace { -using SpatCompute = spatial::SpatCompute; -using ProducerValueRef = spatial::ProducerValueRef; -using spatial::getComputeInstanceInputs; -using spatial::getComputeInstanceOutputTypes; -using spatial::getComputeInstanceOutputValues; -using spatial::getComputeInstanceTemplateBlock; -using spatial::getComputeInstanceWeights; -using spatial::getProducerValueRef; +using CpuId = size_t; +using ClassId = size_t; +using SlotId = size_t; -static Value createIndexConstant(Operation* anchorOp, int64_t value, OperationFolder& folder) { - return getOrCreateHostIndexConstant(anchorOp, value, folder); +struct ProducerKey { + ComputeInstance instance; + size_t resultIndex = 0; + + bool operator==(const ProducerKey& other) const { + return instance == other.instance && resultIndex == other.resultIndex; + } +}; + +struct ProducerKeyInfo { + static ProducerKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProducerKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProducerKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); + } + + static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } +}; + +struct CpuSlotKey { + CpuId cpu = 0; + SlotId slot = 0; + + bool operator==(const CpuSlotKey& other) const { return cpu == other.cpu && slot == other.slot; } +}; + +struct CpuSlotKeyInfo { + static CpuSlotKey getEmptyKey() { + return {std::numeric_limits::max(), std::numeric_limits::max()}; + } + + static CpuSlotKey getTombstoneKey() { + return {std::numeric_limits::max() - 1, std::numeric_limits::max()}; + } + + static unsigned getHashValue(const CpuSlotKey& key) { return llvm::hash_combine(key.cpu, key.slot); } + + static bool isEqual(const CpuSlotKey& lhs, const CpuSlotKey& rhs) { return lhs == rhs; } +}; + +struct ClassSlotKey { + ClassId classId = 0; + SlotId slot = 0; + + bool operator==(const ClassSlotKey& other) const { return classId == other.classId && slot == other.slot; } +}; + +struct ClassSlotKeyInfo { + static ClassSlotKey getEmptyKey() { + return {std::numeric_limits::max(), std::numeric_limits::max()}; + } + + static ClassSlotKey getTombstoneKey() { + return {std::numeric_limits::max() - 1, std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ClassSlotKey& key) { return llvm::hash_combine(key.classId, key.slot); } + + static bool isEqual(const ClassSlotKey& lhs, const ClassSlotKey& rhs) { return lhs == rhs; } +}; + +struct MaterializedClass { + ClassId id = 0; + SmallVector cpus; + Operation* op = nullptr; + Block* body = nullptr; + bool isBatch = false; + + DenseMap cpuToLane; + SmallVector weights; + SmallVector inputs; + SmallVector hostOutputs; + DenseMap weightArgs; + DenseMap inputArgs; + DenseMap hostOutputToResultIndex; +}; + +struct MaterializerState { + func::FuncOp func; + const MergeScheduleResult& schedule; + IRRewriter rewriter; + OperationFolder constantFolder; + int64_t& nextChannelId; + + SmallVector classes; + DenseMap cpuToClass; + DenseMap cpuSlotToInstance; + DenseSet materializedSlots; + + DenseMap, ProducerKeyInfo> producerDestClasses; + DenseMap, ProducerKeyInfo> availableValues; + DenseMap hostReplacements; + DenseSet oldComputeOps; + + MaterializerState(func::FuncOp func, + const MergeScheduleResult& schedule, + int64_t& nextChannelId) + : func(func), schedule(schedule), rewriter(func.getContext()), constantFolder(func.getContext()), + nextChannelId(nextChannelId) {} +}; + +bool isConstantLike(Value value) { + Operation* definingOp = value.getDefiningOp(); + return definingOp && definingOp->hasTrait(); } -static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - SmallVector constants; - constants.reserve(values.size()); - for (int64_t value : values) - constants.push_back(createIndexConstant(anchorOp, value, folder)); - return constants; +bool isInsideOldCompute(Operation* op, const DenseSet& oldComputeOps) { + for (Operation* current = op; current; current = current->getParentOp()) + if (oldComputeOps.contains(current)) + return true; + return false; } -static SmallVector createIndexConstants(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - SmallVector constants; - constants.reserve(values.size()); - for (int32_t value : values) - constants.push_back(createIndexConstant(anchorOp, value, folder)); - return constants; +bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps); + +std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extract) { + if (extract.getMixedOffsets().empty()) + return std::nullopt; + + OpFoldResult offset = extract.getMixedOffsets().front(); + if (auto attr = dyn_cast(offset)) { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() < 0) + return std::nullopt; + return static_cast(intAttr.getInt()); + } + + auto value = cast(offset); + if (auto constantIndex = value.getDefiningOp()) { + if (constantIndex.value() < 0) + return std::nullopt; + return static_cast(constantIndex.value()); + } + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) { + if (constantValue.isNegative()) + return std::nullopt; + return static_cast(constantValue.getZExtValue()); + } + + return std::nullopt; } -static Value createIndexTensorConstant(Operation* anchorOp, ArrayRef values, OperationFolder& folder) { - auto tensorType = - RankedTensorType::get({static_cast(values.size())}, IndexType::get(anchorOp->getContext())); - auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); - return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); +ProducerKey getBatchLaneProducerKey(SpatComputeBatch batch, + uint32_t laneStart, + uint32_t laneCount, + size_t resultIndex) { + return {{batch.getOperation(), laneStart, laneCount}, resultIndex}; } -static Value createIndexTupleTensorConstant( - Operation* anchorOp, int64_t tupleCount, int64_t tupleWidth, ArrayRef values, OperationFolder& folder) { - auto tensorType = RankedTensorType::get({tupleCount, tupleWidth}, IndexType::get(anchorOp->getContext())); - auto tensorAttr = DenseIntElementsAttr::get(tensorType, values); - return getOrCreateHostConstant(anchorOp, tensorAttr, tensorType, folder); +ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { + return getBatchLaneProducerKey(batch, 0, static_cast(batch.getLaneCount()), resultIndex); } -class MergeScheduleMaterializerImpl { +bool isWholeBatchProducerKey(ProducerKey key) { + auto batch = dyn_cast_or_null(key.instance.op); + return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 && + key.instance.laneCount == static_cast(batch.getLaneCount()); +} + +SmallVector expandWholeBatchProducerKey(ProducerKey key) { + if (!isWholeBatchProducerKey(key)) + return SmallVector {key}; + + auto batch = cast(key.instance.op); + SmallVector keys; + keys.reserve(batch.getLaneCount()); + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) + keys.push_back(getBatchLaneProducerKey(batch, lane, 1, key.resultIndex)); + return keys; +} + +std::optional getContiguousProducerKeyForKeys(ArrayRef keys) { + if (keys.empty()) + return std::nullopt; + + ProducerKey first = keys.front(); + auto batch = dyn_cast_or_null(first.instance.op); + if (!batch) + return std::nullopt; + + uint32_t laneStart = first.instance.laneStart; + for (auto [index, key] : llvm::enumerate(keys)) { + if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount != 1) + return std::nullopt; + if (key.instance.laneStart != laneStart + static_cast(index)) + return std::nullopt; + } + + uint32_t laneCount = static_cast(keys.size()); + if (laneStart + laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + + return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); +} + +FailureOr getPackedBatchTensorType(Type laneType, size_t laneCount) { + auto tensorType = dyn_cast(laneType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return failure(); + + SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(laneCount); + return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +std::optional lookupAvailableValue(MaterializerState& state, ProducerKey key, ClassId classId) { + auto producerIt = state.availableValues.find(key); + if (producerIt == state.availableValues.end()) + return std::nullopt; + + auto valueIt = producerIt->second.find(classId); + if (valueIt == producerIt->second.end()) + return std::nullopt; + + return valueIt->second; +} + +std::optional getProducerKey(Value value, const ComputeInstance* consumerInstance = nullptr) { + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return std::nullopt; + + while (auto extract = dyn_cast(definingOp)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + auto result = dyn_cast(source); + if (!result) + return std::nullopt; + + uint32_t laneStart = 0; + uint32_t laneCount = 1; + if (std::optional lane = getConstantFirstSliceOffset(extract)) { + laneStart = *lane; + } + else if (consumerInstance && isa(consumerInstance->op)) { + laneStart = consumerInstance->laneStart; + laneCount = consumerInstance->laneCount; + } + else { + return std::nullopt; + } + + if (laneStart + laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + + return getBatchLaneProducerKey(batch, laneStart, laneCount, result.getResultNumber()); + } + + value = source; + definingOp = value.getDefiningOp(); + if (!definingOp) + return std::nullopt; + } + + if (auto compute = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + return ProducerKey {{compute.getOperation(), 0, 1}, result.getResultNumber()}; + } + + if (auto batch = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + + if (batch.getNumResults() != 0) { + if (consumerInstance && isa(consumerInstance->op)) + return getBatchLaneProducerKey(batch, + consumerInstance->laneStart, + consumerInstance->laneCount, + result.getResultNumber()); + + return getWholeBatchProducerKey(batch, result.getResultNumber()); + } + + return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; + } + + return std::nullopt; +} + +class CpuUnionFind { public: - explicit MergeScheduleMaterializerImpl(func::FuncOp funcOp) - : func(funcOp), - loc(funcOp.getLoc()), - returnOp(cast(funcOp.getBody().front().getTerminator())), - constantFolder(funcOp.getContext()) {} + void insert(CpuId cpu) { parent.try_emplace(cpu, cpu); } - LogicalResult run(const MergeScheduleResult& scheduleResult, int64_t& nextChannelIdRef) { - schedule = &scheduleResult; - nextChannelId = &nextChannelIdRef; + CpuId find(CpuId cpu) { + insert(cpu); + CpuId p = parent.lookup(cpu); + if (p == cpu) + return cpu; + CpuId root = find(p); + parent[cpu] = root; + return root; + } - collectScheduledTasks(); - buildTaskIndex(); - collectExternalInputsAndWeights(); - planRemoteChannels(); - planReceiveReordering(); - createCpuComputeOps(); - if (failed(cloneTaskBodies())) - return failure(); - replaceExternalUses(); - if (failed(eraseOldScheduledOps())) - return failure(); - return success(); + void unite(CpuId lhs, CpuId rhs) { + CpuId lhsRoot = find(lhs); + CpuId rhsRoot = find(rhs); + if (lhsRoot == rhsRoot) + return; + if (rhsRoot < lhsRoot) + std::swap(lhsRoot, rhsRoot); + parent[rhsRoot] = lhsRoot; } private: - struct ScheduledTask { - ComputeInstance computeInstance; - size_t cpu = 0; - size_t orderWithinCpu = 0; - }; + DenseMap parent; +}; - struct ChannelInfo { - int64_t channelId = -1; - int32_t sourceCoreId = -1; - int32_t targetCoreId = -1; - }; +LogicalResult buildEquivalenceClasses(MaterializerState& state) { + DenseSet usedCpus; + for (const auto& entry : state.schedule.cpuToLastComputeMap) + usedCpus.insert(entry.first); + for (const auto& entry : state.schedule.computeToCpuMap) + usedCpus.insert(entry.second); - struct CpuProgram { - Operation* op = nullptr; - DenseMap externalInputMap; - DenseMap weightToIndex; - }; + CpuUnionFind unionFind; + for (CpuId cpu : usedCpus) + unionFind.insert(cpu); - using ProgramKey = size_t; // Represents the "Leader" CPU - DenseMap> batchedCpus; - SmallVector orderedPrograms; - - struct RemoteSendInfo { - ChannelInfo channelInfo; - ComputeInstance consumer; - size_t inputIndex = 0; - size_t consumerOrder = 0; - size_t sourceOrder = 0; - bool isTensorInput = false; - }; - - struct RemoteReceiveEntry { - ChannelInfo channelInfo; - ComputeInstance consumer; - size_t inputIndex = 0; - size_t sourceOrder = 0; - }; - - struct BatchYieldInfo { - Value yieldedValue; - tensor::ParallelInsertSliceOp insertSlice; - }; - - struct TensorChannelInfo { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - SmallVector producerInstances; - size_t resultIndex = 0; - }; - - struct ExtractRowsSendRun { - SpatExtractRowsOp extractRows; - int64_t firstRow = 0; - SmallVector sendCounts; - SmallVector channelSourceTargetTuples; - }; - - struct ExtractSliceSendRun { - tensor::ExtractSliceOp extractSlice; - SmallVector baseOffsets; - unsigned varyingDim = 0; - int64_t offsetStep = 0; - SmallVector sendCounts; - SmallVector channelSourceTargetTuples; - }; - - static uint64_t getRemoteSendPairKey(const ChannelInfo& channelInfo) { - return (static_cast(static_cast(channelInfo.sourceCoreId)) << 32) - | static_cast(channelInfo.targetCoreId); + for (const auto& [cpu, equivalentCpus] : state.schedule.equivalentClass) { + if (!usedCpus.contains(cpu)) + continue; + for (CpuId equivalentCpu : equivalentCpus) + if (usedCpus.contains(equivalentCpu)) + unionFind.unite(cpu, equivalentCpu); } - static ProgramKey getProgramKey(const ScheduledTask& task) { return task.cpu; } + DenseMap> groupsByRoot; + for (CpuId cpu : usedCpus) + groupsByRoot[unionFind.find(cpu)].push_back(cpu); - static bool isResultfulBatchInstance(const ComputeInstance& instance) { - auto batch = dyn_cast(instance.op); - return batch && batch.getNumResults() != 0; - } + SmallVector roots; + roots.reserve(groupsByRoot.size()); + for (const auto& entry : groupsByRoot) + roots.push_back(entry.first); + llvm::sort(roots); - static SmallVector buildPrefixSums(ArrayRef counts) { - SmallVector prefixSums; - prefixSums.reserve(counts.size() + 1); - prefixSums.push_back(0); - int64_t running = 0; - for (int64_t count : counts) { - running += count; - prefixSums.push_back(running); + state.classes.reserve(roots.size()); + for (CpuId root : roots) { + MaterializedClass materializedClass; + materializedClass.id = state.classes.size(); + materializedClass.cpus = groupsByRoot.lookup(root); + llvm::sort(materializedClass.cpus); + materializedClass.isBatch = materializedClass.cpus.size() > 1; + for (auto [lane, cpu] : llvm::enumerate(materializedClass.cpus)) { + materializedClass.cpuToLane[cpu] = static_cast(lane); + state.cpuToClass[cpu] = materializedClass.id; } - return prefixSums; + state.classes.push_back(std::move(materializedClass)); } - FailureOr> collectResultfulBatchYieldInfo(SpatComputeBatch batch) { - Block& block = batch.getBody().front(); - auto inParallel = dyn_cast(block.getTerminator()); - if (!inParallel) - return failure(); - - SmallVector resultInfo(batch.getNumResults()); - DenseMap resultIndexByOutputArg; - for (unsigned resultIndex = 0; resultIndex < batch.getNumResults(); ++resultIndex) - resultIndexByOutputArg[batch.getOutputArgument(resultIndex)] = resultIndex; - - for (Operation& op : inParallel.getRegion().front()) { - auto insertSlice = dyn_cast(&op); - if (!insertSlice) - return failure(); - auto outputArg = dyn_cast(insertSlice.getDest()); - auto resultIndexIt = resultIndexByOutputArg.find(outputArg); - if (resultIndexIt == resultIndexByOutputArg.end()) - return failure(); - resultInfo[resultIndexIt->second] = {insertSlice.getSource(), insertSlice}; - } - - if (llvm::any_of(resultInfo, [](const BatchYieldInfo& info) { return !info.yieldedValue; })) - return failure(); - return resultInfo; + for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { + auto slotIt = state.schedule.computeToCpuSlotMap.find(instance); + if (slotIt == state.schedule.computeToCpuSlotMap.end()) + return instance.op->emitError("schedule materialization expected a CPU slot for every compute instance"); + state.cpuSlotToInstance[{cpu, slotIt->second}] = instance; + state.oldComputeOps.insert(instance.op); } - SmallVector getTaskOutputTypes(const ComputeInstance& instance) { - if (!isResultfulBatchInstance(instance)) - return getComputeInstanceOutputTypes(instance); + return success(); +} - auto batch = cast(instance.op); - FailureOr> yieldInfo = collectResultfulBatchYieldInfo(batch); - if (failed(yieldInfo)) - return {}; +LogicalResult collectHostOutputs(MaterializerState& state) { + DenseSet seenOutputs; + for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(instance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); - SmallVector outputTypes; - outputTypes.reserve(yieldInfo->size()); - for (const BatchYieldInfo& info : *yieldInfo) - outputTypes.push_back(info.yieldedValue.getType()); - return outputTypes; - } - - bool tryCollectExtractRowsSendRun(ArrayRef> sendInfosByResult, - ArrayRef taskYieldValues, - size_t startIndex, - ExtractRowsSendRun& run, - size_t& nextIndex) { - auto firstResult = dyn_cast(taskYieldValues[startIndex]); - auto extractRows = firstResult ? dyn_cast(firstResult.getOwner()) : nullptr; - if (!extractRows || sendInfosByResult[startIndex].empty()) - return false; - - auto inputType = dyn_cast(extractRows.getInput().getType()); - auto rowType = dyn_cast(taskYieldValues[startIndex].getType()); - if (!inputType || !rowType || !inputType.hasStaticShape() || !rowType.hasStaticShape() || inputType.getRank() != 2 - || rowType.getRank() != 2) - return false; - - run = {}; - run.extractRows = extractRows; - run.firstRow = firstResult.getResultNumber(); - - unsigned expectedRow = firstResult.getResultNumber(); - size_t index = startIndex; - while (index < taskYieldValues.size()) { - auto result = dyn_cast(taskYieldValues[index]); - if (!result || result.getOwner() != extractRows.getOperation() || result.getResultNumber() != expectedRow) - break; - - const SmallVector& sendInfos = sendInfosByResult[index]; - run.sendCounts.push_back(static_cast(sendInfos.size())); - for (const RemoteSendInfo& sendInfo : sendInfos) { - run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.channelId); - run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.sourceCoreId); - run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.targetCoreId); - } - - ++index; - ++expectedRow; - } - - nextIndex = index; - return run.sendCounts.size() > 1 && run.channelSourceTargetTuples.size() > 3; - } - - bool tryCollectExtractSliceSendRun(ArrayRef> sendInfosByResult, - ArrayRef taskYieldValues, - size_t startIndex, - ExtractSliceSendRun& run, - size_t& nextIndex) { - auto firstSlice = taskYieldValues[startIndex].getDefiningOp(); - if (!firstSlice || sendInfosByResult[startIndex].empty()) - return false; - if (llvm::any_of(firstSlice.getStaticOffsets(), ShapedType::isDynamic) - || llvm::any_of(firstSlice.getStaticSizes(), ShapedType::isDynamic) - || llvm::any_of(firstSlice.getStaticStrides(), ShapedType::isDynamic)) - return false; - - auto sourceType = dyn_cast(firstSlice.getSourceType()); - auto resultType = dyn_cast(firstSlice.getResultType()); - if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) - return false; - - ArrayRef firstOffsets = firstSlice.getStaticOffsets(); - ArrayRef staticSizes = firstSlice.getStaticSizes(); - ArrayRef staticStrides = firstSlice.getStaticStrides(); - - std::optional varyingDim; - int64_t offsetStep = 0; - run = {}; - run.extractSlice = firstSlice; - run.baseOffsets.assign(firstOffsets.begin(), firstOffsets.end()); - - size_t index = startIndex; - while (index < taskYieldValues.size()) { - auto slice = taskYieldValues[index].getDefiningOp(); - if (!slice || slice.getSource() != firstSlice.getSource() || slice.getResultType() != firstSlice.getResultType() - || slice.getStaticSizes() != staticSizes || slice.getStaticStrides() != staticStrides - || llvm::any_of(slice.getStaticOffsets(), ShapedType::isDynamic)) - break; - - ArrayRef offsets = slice.getStaticOffsets(); - if (index != startIndex) { - SmallVector differingDims; - for (auto [dim, pair] : llvm::enumerate(llvm::zip(firstOffsets, offsets))) - if (std::get<0>(pair) != std::get<1>(pair)) - differingDims.push_back(dim); - if (differingDims.size() != 1) - break; - - unsigned dim = differingDims.front(); - int64_t expectedOffset = firstOffsets[dim] + static_cast(index - startIndex) * offsetStep; - if (!varyingDim) { - varyingDim = dim; - offsetStep = offsets[dim] - firstOffsets[dim]; - expectedOffset = offsets[dim]; - } - if (offsetStep <= 0 || *varyingDim != dim || offsets[dim] != expectedOffset) - break; - } - - const SmallVector& sendInfos = sendInfosByResult[index]; - run.sendCounts.push_back(static_cast(sendInfos.size())); - for (const RemoteSendInfo& sendInfo : sendInfos) { - run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.channelId); - run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.sourceCoreId); - run.channelSourceTargetTuples.push_back(sendInfo.channelInfo.targetCoreId); - } - - ++index; - } - - nextIndex = index; - if (!varyingDim) - return false; - run.varyingDim = *varyingDim; - run.offsetStep = offsetStep; - return run.sendCounts.size() > 1 && run.channelSourceTargetTuples.size() > 3; - } - - void emitInnerSendLoop(Operation* hostAnchor, - IRRewriter& rewriter, - Value sliceValue, - Value lower, - Value upper, - Value channelSourceTargetTuples) { - Value step = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); - auto innerLoop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); - rewriter.setInsertionPointToStart(innerLoop.getBody()); - Value sendIndex = innerLoop.getInductionVar(); - Value tupleChannelIndex = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); - Value tupleSourceIndex = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); - Value tupleTargetIndex = getOrCreateHostIndexConstant(hostAnchor, 2, constantFolder); - Value channelIdIndex = - tensor::ExtractOp::create(rewriter, loc, channelSourceTargetTuples, ValueRange {sendIndex, tupleChannelIndex}); - Value sourceCoreIdIndex = - tensor::ExtractOp::create(rewriter, loc, channelSourceTargetTuples, ValueRange {sendIndex, tupleSourceIndex}); - Value targetCoreIdIndex = - tensor::ExtractOp::create(rewriter, loc, channelSourceTargetTuples, ValueRange {sendIndex, tupleTargetIndex}); - spatial::SpatChannelSendOp::create(rewriter, loc, channelIdIndex, sourceCoreIdIndex, targetCoreIdIndex, sliceValue); - rewriter.setInsertionPointAfter(innerLoop); - } - - void emitExtractRowsSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractRowsSendRun& run) { - SmallVector prefixSums = buildPrefixSums(run.sendCounts); - Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); - Value channelSourceTargetTuples = - createIndexTupleTensorConstant(hostAnchor, - static_cast(run.channelSourceTargetTuples.size() / 3), - 3, - run.channelSourceTargetTuples, - constantFolder); - - Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); - Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast(run.sendCounts.size()), constantFolder); - Value step = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); - auto outerLoop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); - rewriter.setInsertionPointToStart(outerLoop.getBody()); - - Value rowIndex = outerLoop.getInductionVar(); - if (run.firstRow != 0) { - Value firstRow = getOrCreateHostIndexConstant(hostAnchor, run.firstRow, constantFolder); - rowIndex = arith::AddIOp::create(rewriter, loc, rowIndex, firstRow); - } - - auto rowType = cast(run.extractRows.getResult(0).getType()); - int64_t rowHeight = rowType.getShape()[0]; - if (rowHeight != 1) { - Value rowHeightValue = getOrCreateHostIndexConstant(hostAnchor, rowHeight, constantFolder); - rowIndex = arith::MulIOp::create(rewriter, loc, rowIndex, rowHeightValue); - } - - auto inputType = cast(run.extractRows.getInput().getType()); - SmallVector offsets = {rowIndex, rewriter.getIndexAttr(0)}; - SmallVector sizes = {rewriter.getIndexAttr(rowType.getShape()[0]), - rewriter.getIndexAttr(inputType.getShape()[1])}; - SmallVector strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; - Value extractedRow = - tensor::ExtractSliceOp::create(rewriter, loc, rowType, run.extractRows.getInput(), offsets, sizes, strides) - .getResult(); - - Value nextRowIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); - Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); - Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextRowIndex}); - emitInnerSendLoop(hostAnchor, rewriter, extractedRow, innerLower, innerUpper, channelSourceTargetTuples); - rewriter.setInsertionPointAfter(outerLoop); - } - - void emitExtractSliceSendRun(Operation* hostAnchor, IRRewriter& rewriter, ExtractSliceSendRun& run) { - SmallVector prefixSums = buildPrefixSums(run.sendCounts); - Value prefixTensor = createIndexTensorConstant(hostAnchor, prefixSums, constantFolder); - Value channelSourceTargetTuples = - createIndexTupleTensorConstant(hostAnchor, - static_cast(run.channelSourceTargetTuples.size() / 3), - 3, - run.channelSourceTargetTuples, - constantFolder); - - Value lower = getOrCreateHostIndexConstant(hostAnchor, 0, constantFolder); - Value upper = getOrCreateHostIndexConstant(hostAnchor, static_cast(run.sendCounts.size()), constantFolder); - Value step = getOrCreateHostIndexConstant(hostAnchor, 1, constantFolder); - auto outerLoop = scf::ForOp::create(rewriter, loc, lower, upper, step, ValueRange {}); - rewriter.setInsertionPointToStart(outerLoop.getBody()); - - SmallVector offsets; - offsets.reserve(run.baseOffsets.size()); - for (auto [dim, offset] : llvm::enumerate(run.baseOffsets)) { - if (dim != run.varyingDim) { - offsets.push_back(rewriter.getIndexAttr(offset)); - continue; - } - - Value varyingOffset = outerLoop.getInductionVar(); - if (run.offsetStep != 1) { - Value offsetStep = getOrCreateHostIndexConstant(hostAnchor, run.offsetStep, constantFolder); - varyingOffset = arith::MulIOp::create(rewriter, loc, varyingOffset, offsetStep); - } - if (offset != 0) { - Value baseOffset = getOrCreateHostIndexConstant(hostAnchor, offset, constantFolder); - varyingOffset = arith::AddIOp::create(rewriter, loc, varyingOffset, baseOffset); - } - offsets.push_back(varyingOffset); - } - - SmallVector sizes; - SmallVector strides; - for (int64_t size : run.extractSlice.getStaticSizes()) - sizes.push_back(rewriter.getIndexAttr(size)); - for (int64_t stride : run.extractSlice.getStaticStrides()) - strides.push_back(rewriter.getIndexAttr(stride)); - - Value extractedSlice = - tensor::ExtractSliceOp::create( - rewriter, loc, run.extractSlice.getResultType(), run.extractSlice.getSource(), offsets, sizes, strides) - .getResult(); - - Value nextSliceIndex = arith::AddIOp::create(rewriter, loc, outerLoop.getInductionVar(), step); - Value innerLower = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {outerLoop.getInductionVar()}); - Value innerUpper = tensor::ExtractOp::create(rewriter, loc, prefixTensor, ValueRange {nextSliceIndex}); - emitInnerSendLoop(hostAnchor, rewriter, extractedSlice, innerLower, innerUpper, channelSourceTargetTuples); - rewriter.setInsertionPointAfter(outerLoop); - } - - bool tryEmitCompactSendLoops(Operation* hostAnchor, - IRRewriter& rewriter, - ArrayRef> sendInfosByResult, - ArrayRef taskYieldValues, - size_t startIndex, - size_t& nextIndex) { - ExtractRowsSendRun extractRowsRun; - if (tryCollectExtractRowsSendRun(sendInfosByResult, taskYieldValues, startIndex, extractRowsRun, nextIndex)) { - emitExtractRowsSendRun(hostAnchor, rewriter, extractRowsRun); - return true; - } - - ExtractSliceSendRun extractSliceRun; - if (tryCollectExtractSliceSendRun(sendInfosByResult, taskYieldValues, startIndex, extractSliceRun, nextIndex)) { - emitExtractSliceSendRun(hostAnchor, rewriter, extractSliceRun); - return true; - } - - return false; - } - - size_t getLoopableResultlessBatchRunLength(ArrayRef programTasks, size_t startIndex) { - const ScheduledTask& firstTask = programTasks[startIndex]; - auto batch = dyn_cast(firstTask.computeInstance.op); - if (!batch || batch.getNumResults() != 0 || firstTask.computeInstance.laneCount != 1) - return 1; - - SmallVector firstInputs = getComputeInstanceInputs(firstTask.computeInstance); - SmallVector firstWeights = getComputeInstanceWeights(firstTask.computeInstance); - size_t runLength = 1; - uint32_t nextLane = firstTask.computeInstance.laneStart + 1; - for (size_t index = startIndex + 1; index < programTasks.size(); ++index) { - const ScheduledTask& candidate = programTasks[index]; - if (candidate.computeInstance.op != firstTask.computeInstance.op || candidate.computeInstance.laneCount != 1 - || candidate.computeInstance.laneStart != nextLane) - break; - if (getComputeInstanceInputs(candidate.computeInstance) != firstInputs - || getComputeInstanceWeights(candidate.computeInstance) != firstWeights) - break; - - ++runLength; - ++nextLane; - } - return runLength; - } - - void collectScheduledTasks() { - for (ComputeInstance scheduledInstance : schedule->dominanceOrderCompute) { - oldComputeOps.insert(scheduledInstance.op); - scheduledTasks.push_back({scheduledInstance, - schedule->computeToCpuMap.lookup(scheduledInstance), - schedule->computeToCpuSlotMap.lookup(scheduledInstance)}); - } - } - - void buildTaskIndex() { - DenseSet seen; - for (const ScheduledTask& task : scheduledTasks) { - taskByComputeInstance[task.computeInstance] = task; - tasksByCpu[task.cpu].push_back(task); - seen.insert(task.cpu); - } - - SmallVector activeCpus(seen.begin(), seen.end()); - llvm::sort(activeCpus); - - DenseSet batched; - for (size_t cpu : activeCpus) { - if (batched.contains(cpu)) + MaterializedClass& materializedClass = state.classes[state.cpuToClass.lookup(cpuIt->second)]; + for (Value output : getComputeInstanceOutputValues(instance)) { + if (!hasLiveExternalUse(output, state.oldComputeOps) || !seenOutputs.insert(output).second) continue; - SmallVector batch; - batch.push_back(cpu); - batched.insert(cpu); - - // Group all equivalent CPUs into this batch - auto it = schedule->equivalentClass.find(cpu); - if (it != schedule->equivalentClass.end()) { - for (size_t eqCpu : it->second) - if (batched.insert(eqCpu).second) - batch.push_back(eqCpu); - } - - llvm::sort(batch); - size_t leader = batch.front(); - batchedCpus[leader] = batch; - orderedPrograms.push_back(leader); - } - - for (size_t cpu : activeCpus) { - llvm::stable_sort(tasksByCpu[cpu], [&](const ScheduledTask& lhs, const ScheduledTask& rhs) { - return lhs.orderWithinCpu < rhs.orderWithinCpu; - }); + materializedClass.hostOutputToResultIndex[output] = materializedClass.hostOutputs.size(); + materializedClass.hostOutputs.push_back(output); } } - void collectExternalInputsAndWeights() { - for (ProgramKey leader : orderedPrograms) { - const auto& batch = batchedCpus[leader]; - auto& thisCpuWeights = cpuWeights[leader]; - auto& thisCpuInputs = cpuExternalInputs[leader]; - auto& thisCpuOutputs = cpuExternalOutputs[leader]; + return success(); +} - // Process every lane sequentially to pack operands - for (size_t cpu : batch) { - DenseSet laneSeenWeights; - DenseSet laneSeenInputs; +void setOperandSegmentSizes(Operation* op, int weightCount, int inputCount) { + if (auto compute = dyn_cast(op)) { + compute.getProperties().setOperandSegmentSizes({weightCount, inputCount}); + return; + } + auto batch = cast(op); + batch.getProperties().setOperandSegmentSizes({weightCount, inputCount}); +} - for (const ScheduledTask& task : tasksByCpu[cpu]) { - for (Value weight : getComputeInstanceWeights(task.computeInstance)) - if (laneSeenWeights.insert(weight).second) - thisCpuWeights.push_back(weight); +void createEmptyMaterializedOps(MaterializerState& state) { + Location loc = state.func.getLoc(); + Block& funcBlock = state.func.getBody().front(); - auto taskInputs = getComputeInstanceInputs(task.computeInstance); - auto& remoteInputs = remoteInputsByTask[task.computeInstance]; - remoteInputs.resize(taskInputs.size()); - - for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { - bool isExternalInput = true; - if (auto producerRef = getProducerValueRef(input, &task.computeInstance)) { - auto producerIt = taskByComputeInstance.find(producerRef->instance); - if (producerIt != taskByComputeInstance.end()) { - isExternalInput = false; - if (producerIt->second.cpu != cpu) { - // Cross-core communication - ChannelInfo info; - info.channelId = (*nextChannelId)++; - info.sourceCoreId = static_cast(producerIt->second.cpu); - info.targetCoreId = static_cast(cpu); - - remoteInputs[inputIndex] = info; - auto& perResultChannels = remoteSendsByTask[producerRef->instance]; - if (perResultChannels.empty()) - perResultChannels.resize(getTaskOutputTypes(producerIt->second.computeInstance).size()); - - RemoteSendInfo sendInfo; - sendInfo.channelInfo = info; - sendInfo.consumer = task.computeInstance; - sendInfo.inputIndex = inputIndex; - sendInfo.consumerOrder = task.orderWithinCpu; - sendInfo.sourceOrder = 0; - sendInfo.isTensorInput = false; - perResultChannels[producerRef->resultIndex].push_back(sendInfo); - } - } - } - if (isExternalInput && laneSeenInputs.insert(input).second) - thisCpuInputs.push_back(input); - } - - // Define the logical return types based strictly on the Leader - if (cpu == leader) { - auto taskOutputs = getComputeInstanceOutputValues(task.computeInstance); - for (auto [resultIndex, output] : llvm::enumerate(taskOutputs)) { - bool hasExternalUser = false; - for (auto& use : output.getUses()) - if (!oldComputeOps.contains(use.getOwner())) - hasExternalUser = true; - if (hasExternalUser) - thisCpuOutputs.push_back({task.computeInstance, resultIndex}); - } - } - } - } + Operation* firstOldCompute = nullptr; + for (Operation& op : funcBlock) { + if (state.oldComputeOps.contains(&op)) { + firstOldCompute = &op; + break; } } - void planRemoteChannels() { - for (size_t cpu : orderedCpus) { - DenseMap nextSourceOrderByPair; - DenseMap lastConsumerOrderByPair; - for (const ScheduledTask& task : tasksByCpu[cpu]) { - auto sendsIt = remoteSendsByTask.find(task.computeInstance); - if (sendsIt == remoteSendsByTask.end()) + if (firstOldCompute) + state.rewriter.setInsertionPoint(firstOldCompute); + else + state.rewriter.setInsertionPointToStart(&funcBlock); + + for (MaterializedClass& materializedClass : state.classes) { + SmallVector resultTypes; + resultTypes.reserve(materializedClass.hostOutputs.size()); + for (Value output : materializedClass.hostOutputs) + resultTypes.push_back(output.getType()); + + if (!materializedClass.isBatch) { + auto compute = SpatCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); + compute.getProperties().setOperandSegmentSizes({0, 0}); + compute->setAttr(onnx_mlir::kCoreIdAttrName, + state.rewriter.getI32IntegerAttr(static_cast(materializedClass.cpus.front()))); + Block* body = state.rewriter.createBlock(&compute.getBody()); + state.rewriter.setInsertionPointToEnd(body); + SmallVector placeholderOutputs; + placeholderOutputs.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto tensorType = dyn_cast(resultType); + if (!tensorType || !tensorType.hasStaticShape()) { + compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); continue; - for (auto& sendInfos : sendsIt->second) { - for (RemoteSendInfo& sendInfo : sendInfos) { - if (sendInfo.isTensorInput) - continue; - uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); - sendInfo.sourceOrder = nextSourceOrderByPair[pairKey]++; - auto [it, inserted] = lastConsumerOrderByPair.try_emplace(pairKey, sendInfo.consumerOrder); - if (!inserted) { - if (sendInfo.consumerOrder < it->second) - pairsNeedingReceiveReorder.insert(pairKey); - it->second = sendInfo.consumerOrder; - } - } } + placeholderOutputs.push_back( + tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); } + SpatYieldOp::create(state.rewriter, loc, ValueRange(placeholderOutputs)); + materializedClass.op = compute.getOperation(); + materializedClass.body = body; + state.rewriter.setInsertionPointAfter(compute.getOperation()); + continue; } + + auto batch = SpatComputeBatch::create(state.rewriter, + loc, + TypeRange(resultTypes), + state.rewriter.getI32IntegerAttr(static_cast(materializedClass.cpus.size())), + ValueRange {}, + ValueRange {}); + batch.getProperties().setOperandSegmentSizes({0, 0}); + SmallVector coreIds; + coreIds.reserve(materializedClass.cpus.size()); + for (CpuId cpu : materializedClass.cpus) + coreIds.push_back(static_cast(cpu)); + batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(coreIds)); + + SmallVector blockArgTypes {state.rewriter.getIndexType()}; + SmallVector blockArgLocs {loc}; + llvm::append_range(blockArgTypes, resultTypes); + blockArgLocs.append(resultTypes.size(), loc); + Block* body = state.rewriter.createBlock( + &batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + state.rewriter.setInsertionPointToEnd(body); + if (resultTypes.empty()) { + SpatYieldOp::create(state.rewriter, loc, ValueRange {}); + } + else { + SpatInParallelOp::create(state.rewriter, loc); + } + materializedClass.op = batch.getOperation(); + materializedClass.body = body; + state.rewriter.setInsertionPointAfter(batch.getOperation()); + } +} + +BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { + auto it = materializedClass.weightArgs.find(weight); + if (it != materializedClass.weightArgs.end()) + return it->second; + + unsigned weightIndex = materializedClass.weights.size(); + materializedClass.weights.push_back(weight); + + if (auto compute = dyn_cast(materializedClass.op)) { + compute.getWeightsMutable().append(ValueRange(weight)); + setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); + BlockArgument arg = materializedClass.body->insertArgument(weightIndex, weight.getType(), weight.getLoc()); + materializedClass.weightArgs[weight] = arg; + return arg; } - void planReceiveReordering() { - DenseMap> reorderedSendsByPair; - for (auto& taskSends : remoteSendsByTask) { - for (auto& sendInfos : taskSends.second) { - for (RemoteSendInfo& sendInfo : sendInfos) { - if (sendInfo.isTensorInput) - continue; - uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); - if (pairsNeedingReceiveReorder.contains(pairKey)) - reorderedSendsByPair[pairKey].push_back(&sendInfo); - } - } - } + auto batch = cast(materializedClass.op); + batch.getWeightsMutable().append(ValueRange(weight)); + setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); + BlockArgument arg = materializedClass.body->insertArgument(1 + weightIndex, weight.getType(), weight.getLoc()); + materializedClass.weightArgs[weight] = arg; + return arg; +} - for (auto& pairSends : reorderedSendsByPair) { - llvm::stable_sort(pairSends.second, [](const RemoteSendInfo* lhs, const RemoteSendInfo* rhs) { - if (lhs->sourceOrder != rhs->sourceOrder) - return lhs->sourceOrder < rhs->sourceOrder; - return lhs->channelInfo.channelId < rhs->channelInfo.channelId; - }); - for (RemoteSendInfo* sendInfo : pairSends.second) { - int64_t channelId = (*nextChannelId)++; - sendInfo->channelInfo.channelId = channelId; - auto remoteInputsIt = remoteInputsByTask.find(sendInfo->consumer); - assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for reordered send"); - assert(sendInfo->inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); - assert(remoteInputsIt->second[sendInfo->inputIndex] && "missing reordered remote input channel"); - remoteInputsIt->second[sendInfo->inputIndex]->channelId = channelId; - } - } +BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) { + auto it = materializedClass.inputArgs.find(input); + if (it != materializedClass.inputArgs.end()) + return it->second; - for (const auto& taskSends : remoteSendsByTask) { - for (const auto& sendInfos : taskSends.second) { - for (const RemoteSendInfo& sendInfo : sendInfos) { - if (sendInfo.isTensorInput) - continue; - auto remoteInputsIt = remoteInputsByTask.find(sendInfo.consumer); - assert(remoteInputsIt != remoteInputsByTask.end() && "missing remote input for send"); - assert(sendInfo.inputIndex < remoteInputsIt->second.size() && "remote input index out of range"); - assert(remoteInputsIt->second[sendInfo.inputIndex] && "missing remote input channel"); - remoteInputsIt->second[sendInfo.inputIndex] = sendInfo.channelInfo; - } - } - } + materializedClass.inputs.push_back(input); + if (auto compute = dyn_cast(materializedClass.op)) + compute.getInputsMutable().append(ValueRange(input)); + else + cast(materializedClass.op).getInputsMutable().append(ValueRange(input)); + setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size()); - for (auto& taskSends : remoteSendsByTask) { - for (const auto& sendInfos : taskSends.second) { - for (const RemoteSendInfo& sendInfo : sendInfos) { - if (sendInfo.isTensorInput) - continue; - uint64_t pairKey = getRemoteSendPairKey(sendInfo.channelInfo); - if (!pairsNeedingReceiveReorder.contains(pairKey)) - continue; - size_t targetCpu = static_cast(sendInfo.channelInfo.targetCoreId); - receiveQueuesByCpu[targetCpu][pairKey].push_back( - {sendInfo.channelInfo, sendInfo.consumer, sendInfo.inputIndex, sendInfo.sourceOrder}); - } - } - } + BlockArgument arg = materializedClass.body->addArgument(input.getType(), input.getLoc()); + materializedClass.inputArgs[input] = arg; + return arg; +} - for (auto& cpuQueues : receiveQueuesByCpu) { - for (auto& pairQueue : cpuQueues.second) { - llvm::stable_sort(pairQueue.second, [](const RemoteReceiveEntry& lhs, const RemoteReceiveEntry& rhs) { - if (lhs.sourceOrder != rhs.sourceOrder) - return lhs.sourceOrder < rhs.sourceOrder; - return lhs.channelInfo.channelId < rhs.channelInfo.channelId; - }); - } - } - } +Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t value) { + return getOrCreateHostIndexConstant(anchor, value, state.constantFolder); +} - void createCpuComputeOps() { - IRRewriter rewriter(func.getContext()); - for (ProgramKey leader : orderedPrograms) { - const auto& batch = batchedCpus[leader]; - bool isBatch = batch.size() > 1; +SmallVector createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef values) { + SmallVector constants; + constants.reserve(values.size()); + for (int64_t value : values) + constants.push_back(createIndexConstant(state, anchor, value)); + return constants; +} - SmallVector resultTypes; - SmallVector packedResultTypes; +SmallVector createIndexConstants(MaterializerState& state, Operation* anchor, ArrayRef values) { + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); + return createIndexConstants(state, anchor, ArrayRef(widened)); +} - for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) { - ScheduledTask task = taskByComputeInstance.at(outputRef.instance); - Type elemType = getTaskOutputTypes(task.computeInstance)[outputRef.resultIndex]; - resultTypes.push_back(elemType); - - if (isBatch) { - auto ranked = cast(elemType); - SmallVector shape; - shape.push_back(static_cast(batch.size())); - shape.append(ranked.getShape().begin(), ranked.getShape().end()); - packedResultTypes.push_back(RankedTensorType::get(shape, ranked.getElementType())); - } - } - - rewriter.setInsertionPoint(returnOp); - CpuProgram program; - - if (!isBatch) { - // Isolated CPU Execution - SmallVector operands; - operands.reserve(cpuWeights[leader].size() + cpuExternalInputs[leader].size()); - llvm::append_range(operands, cpuWeights[leader]); - llvm::append_range(operands, cpuExternalInputs[leader]); - - auto newCompute = SpatCompute::create(rewriter, loc, TypeRange(resultTypes), ValueRange(operands)); - newCompute.getProperties().setOperandSegmentSizes( - {static_cast(cpuWeights[leader].size()), static_cast(cpuExternalInputs[leader].size())}); - newCompute->setAttr(onnx_mlir::kCoreIdAttrName, rewriter.getI32IntegerAttr(static_cast(leader))); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - for (Value op : operands) { - blockArgTypes.push_back(op.getType()); - blockArgLocs.push_back(loc); - } - rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - - program.op = newCompute; - for (auto [weightIndex, weight] : llvm::enumerate(cpuWeights[leader])) - program.weightToIndex[weight] = weightIndex; - for (auto [inputIndex, input] : llvm::enumerate(cpuExternalInputs[leader])) - program.externalInputMap[input] = newCompute.getInputArgument(inputIndex); - - for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) { - ScheduledTask task = taskByComputeInstance.at(outputRef.instance); - if (isResultfulBatchInstance(task.computeInstance)) { - auto oldBatch = cast(task.computeInstance.op); - auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()]; - if (batchResults.empty()) - batchResults.resize(oldBatch.getNumResults()); - auto& laneResults = batchResults[outputRef.resultIndex]; - if (laneResults.empty()) - laneResults.resize(static_cast(oldBatch.getLaneCount())); - laneResults[task.computeInstance.laneStart] = newCompute.getResult(resultIndex); - continue; - } - oldToNewExternalValueMap[getComputeInstanceOutputValues(task.computeInstance)[outputRef.resultIndex]] = - newCompute.getResult(resultIndex); - } - } - else { - // Equivalence Class Batch Execution - auto newBatch = SpatComputeBatch::create(rewriter, - loc, - TypeRange(packedResultTypes), - rewriter.getI32IntegerAttr(batch.size()), - cpuWeights[leader], - cpuExternalInputs[leader]); - newBatch.getProperties().setOperandSegmentSizes( - {static_cast(cpuWeights[leader].size()), static_cast(cpuExternalInputs[leader].size())}); - - SmallVector coreIds; - for (size_t c : batch) - coreIds.push_back(static_cast(c)); - newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - blockArgTypes.push_back(rewriter.getIndexType()); // Lane ID Argument - blockArgLocs.push_back(loc); - - size_t weightsPerLane = cpuWeights[leader].size() / batch.size(); - size_t inputsPerLane = cpuExternalInputs[leader].size() / batch.size(); - - for (size_t i = 0; i < weightsPerLane; ++i) { - blockArgTypes.push_back(cpuWeights[leader][i].getType()); - blockArgLocs.push_back(loc); - } - for (size_t i = 0; i < inputsPerLane; ++i) { - blockArgTypes.push_back(cpuExternalInputs[leader][i].getType()); - blockArgLocs.push_back(loc); - } - for (Type t : packedResultTypes) { - blockArgTypes.push_back(t); - blockArgLocs.push_back(loc); - } // Dest tensors for InParallel Yield - - rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - program.op = newBatch; - - // Host-side slice extractions - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(newBatch); - for (auto [laneIndex, cpu] : llvm::enumerate(batch)) { - for (auto [resultIndex, outputRef] : llvm::enumerate(cpuExternalOutputs[leader])) { - size_t taskIdx = 0; - for (size_t i = 0; i < tasksByCpu[leader].size(); ++i) { - if (tasksByCpu[leader][i].computeInstance == outputRef.instance) { - taskIdx = i; - break; - } - } - ComputeInstance laneInstance = tasksByCpu[cpu][taskIdx].computeInstance; - - auto ranked = cast(resultTypes[resultIndex]); - SmallVector offsets; - offsets.push_back(rewriter.getIndexAttr(laneIndex)); - SmallVector sizes; - sizes.push_back(rewriter.getIndexAttr(1)); - SmallVector strides; - strides.push_back(rewriter.getIndexAttr(1)); - - for (int64_t dim : ranked.getShape()) { - offsets.push_back(rewriter.getIndexAttr(0)); - sizes.push_back(rewriter.getIndexAttr(dim)); - strides.push_back(rewriter.getIndexAttr(1)); - } - auto slice = tensor::ExtractSliceOp::create( - rewriter, loc, ranked, newBatch.getResult(resultIndex), offsets, sizes, strides); - - if (isResultfulBatchInstance(laneInstance)) { - auto oldBatch = cast(laneInstance.op); - auto& batchResults = resultfulBatchLaneResults[oldBatch.getOperation()]; - if (batchResults.empty()) - batchResults.resize(oldBatch.getNumResults()); - auto& laneValues = batchResults[outputRef.resultIndex]; - if (laneValues.empty()) - laneValues.resize(static_cast(oldBatch.getLaneCount())); - laneValues[laneInstance.laneStart] = slice.getResult(); - } - else { - oldToNewExternalValueMap[getComputeInstanceOutputValues(laneInstance)[outputRef.resultIndex]] = - slice.getResult(); - } - } - } - - for (size_t i = 0; i < weightsPerLane; ++i) - program.weightToIndex[cpuWeights[leader][i]] = i; - for (size_t i = 0; i < inputsPerLane; ++i) - program.externalInputMap[cpuExternalInputs[leader][i]] = - newBatch.getBody().front().getArgument(1 + weightsPerLane + i); - } - cpuPrograms[leader] = std::move(program); - } - } - - FailureOr receiveThroughInput(IRRewriter& rewriter, - size_t cpu, - DenseMap& receiveQueueIndices, - DenseMap>& preReceivedInputsByTask, - const ChannelInfo& requestedChannelInfo, - ComputeInstance requestedConsumer, - size_t requestedInputIndex) { - uint64_t pairKey = getRemoteSendPairKey(requestedChannelInfo); - auto cpuQueuesIt = receiveQueuesByCpu.find(cpu); - if (cpuQueuesIt == receiveQueuesByCpu.end()) +FailureOr> getPeerInstances(MaterializerState& state, + const MaterializedClass& materializedClass, + SlotId slot) { + SmallVector peers; + peers.reserve(materializedClass.cpus.size()); + for (CpuId cpu : materializedClass.cpus) { + auto it = state.cpuSlotToInstance.find({cpu, slot}); + if (it == state.cpuSlotToInstance.end()) return failure(); - auto queueIt = cpuQueuesIt->second.find(pairKey); - if (queueIt == cpuQueuesIt->second.end()) + peers.push_back(it->second); + } + return peers; +} + +Value createOriginalLaneValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef peers, + Location loc) { + assert(!peers.empty() && "expected at least one peer instance"); + if (!materializedClass.isBatch) + return createIndexConstant(state, materializedClass.op, peers.front().laneStart); + + auto batch = cast(materializedClass.op); + bool identity = true; + for (auto [lane, peer] : llvm::enumerate(peers)) { + if (peer.laneCount != 1 || peer.laneStart != lane) { + identity = false; + break; + } + } + if (identity) + return batch.getLaneArgument(); + + bool affineWithBase = true; + int64_t base = static_cast(peers.front().laneStart); + for (auto [lane, peer] : llvm::enumerate(peers)) { + if (peer.laneCount != 1 || static_cast(peer.laneStart) != base + static_cast(lane)) { + affineWithBase = false; + break; + } + } + if (affineWithBase) { + if (base == 0) + return batch.getLaneArgument(); + Value baseValue = createIndexConstant(state, materializedClass.op, base); + return arith::AddIOp::create(state.rewriter, loc, batch.getLaneArgument(), baseValue).getResult(); + } + + SmallVector laneValues; + laneValues.reserve(peers.size()); + for (const ComputeInstance& peer : peers) + laneValues.push_back(APInt(64, peer.laneStart)); + + auto tableType = RankedTensorType::get({static_cast(peers.size())}, state.rewriter.getIndexType()); + auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues); + Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult(); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {batch.getLaneArgument()}).getResult(); +} + +bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) { + SmallVector worklist {value}; + DenseSet visited; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current).second) + continue; + + for (OpOperand& use : current.getUses()) { + Operation* owner = use.getOwner(); + if (isInsideOldCompute(owner, oldComputeOps)) + continue; + if (isa(owner)) { + for (Value result : owner->getResults()) + worklist.push_back(result); + continue; + } + return true; + } + } + + return false; +} + +void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet& oldComputeOps) { + SmallVector uses; + for (OpOperand& use : oldValue.getUses()) + uses.push_back(&use); + + for (OpOperand* use : uses) { + Operation* owner = use->getOwner(); + if (isInsideOldCompute(owner, oldComputeOps)) + continue; + use->set(replacement); + } +} + +LogicalResult collectProducerDestinations(MaterializerState& state) { + for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(consumer); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return consumer.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); + ClassId targetClass = state.cpuToClass.lookup(cpuIt->second); + + for (Value input : getComputeInstanceInputs(consumer)) { + std::optional producer = getProducerKey(input, &consumer); + if (!producer) + continue; + + for (ProducerKey producerKey : expandWholeBatchProducerKey(*producer)) { + auto producerCpuIt = state.schedule.computeToCpuMap.find(producerKey.instance); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + return consumer.op->emitError("schedule materialization found an input produced by an unscheduled compute"); + + ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClass == targetClass) + continue; + + state.producerDestClasses[producerKey].insert(targetClass); + } + } + } + + return success(); +} + +SmallVector getOutputKeysForPeers(ArrayRef peers, size_t resultIndex) { + SmallVector keys; + keys.reserve(peers.size()); + for (const ComputeInstance& peer : peers) + keys.push_back({peer, resultIndex}); + return keys; +} + +bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { + if (keys.empty()) + return true; + + auto firstIt = state.producerDestClasses.find(keys.front()); + DenseSet empty; + const DenseSet& first = firstIt == state.producerDestClasses.end() ? empty : firstIt->second; + for (ProducerKey key : keys.drop_front()) { + auto it = state.producerDestClasses.find(key); + const DenseSet& current = it == state.producerDestClasses.end() ? empty : it->second; + if (first.size() != current.size()) + return false; + for (ClassId classId : first) + if (!current.contains(classId)) + return false; + } + return true; +} + +SmallVector getSortedDestinationClasses(MaterializerState& state, ProducerKey key) { + SmallVector destinations; + auto it = state.producerDestClasses.find(key); + if (it == state.producerDestClasses.end()) + return destinations; + for (ClassId classId : it->second) + destinations.push_back(classId); + llvm::sort(destinations); + return destinations; +} + +Value appendReceive(MaterializerState& state, + MaterializedClass& targetClass, + Type type, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + SmallVector channelIdValues = createIndexConstants(state, targetClass.op, channelIds); + SmallVector sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds); + SmallVector targetCoreIdValues = createIndexConstants(state, targetClass.op, targetCoreIds); + + if (targetClass.isBatch) { + return SpatChannelReceiveBatchOp::create( + state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) + .getOutput(); + } + + if (channelIds.size() != 1) { + return SpatChannelReceiveTensorOp::create( + state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) + .getOutput(); + } + + return SpatChannelReceiveOp::create(state.rewriter, + loc, + type, + channelIdValues.front(), + sourceCoreIdValues.front(), + targetCoreIdValues.front()) + .getOutput(); +} + +Value appendHostReceive(MaterializerState& state, + MaterializedClass& sourceClass, + Type type, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + state.rewriter.setInsertionPointAfter(sourceClass.op); + SmallVector channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); + SmallVector sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); + SmallVector targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds); + + if (sourceClass.isBatch) { + return SpatChannelReceiveTensorOp::create( + state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) + .getOutput(); + } + + assert(channelIds.size() == 1 && "scalar host receive expects one channel"); + return SpatChannelReceiveOp::create(state.rewriter, + loc, + type, + channelIdValues.front(), + sourceCoreIdValues.front(), + targetCoreIdValues.front()) + .getOutput(); +} + +LogicalResult setHostOutputValue(MaterializerState& state, + MaterializedClass& sourceClass, + Value originalOutput, + Value payload) { + auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); + if (resultIt == sourceClass.hostOutputToResultIndex.end()) + return sourceClass.op->emitError("missing host result slot for materialized output"); + + unsigned resultIndex = resultIt->second; + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); + + if (!sourceClass.isBatch) { + auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); + if (!yieldOp) + return sourceClass.op->emitError("expected spat.yield terminator in materialized compute"); + if (resultIndex >= yieldOp.getNumOperands()) + return sourceClass.op->emitError("host result index out of range for materialized compute"); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); + return success(); + } + + auto batch = cast(sourceClass.op); + auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); + if (!inParallelOp) + return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape()) + return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); + + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(payloadType.getRank()); + sizes.reserve(payloadType.getRank()); + strides.reserve(payloadType.getRank()); + offsets.push_back(batch.getLaneArgument()); + sizes.push_back(state.rewriter.getIndexAttr(1)); + strides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { + offsets.push_back(state.rewriter.getIndexAttr(0)); + sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + tensor::ParallelInsertSliceOp::create( + state.rewriter, payload.getLoc(), payload, batch.getOutputArgument(resultIndex), offsets, sizes, strides); + return success(); +} + +void appendScalarSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId); + Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId); + Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId); + SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); +} + +void appendBatchSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + SmallVector channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); + SmallVector sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); + SmallVector targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds); + SpatChannelSendBatchOp::create(state.rewriter, loc, channelIdValues, sourceCoreIdValues, targetCoreIdValues, payload); +} + +LogicalResult emitClassToClassCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Value payload, + Location loc) { + if (sourceClass.id == targetClass.id) { + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = payload; + return success(); + } + + if (!sourceClass.isBatch && !targetClass.isBatch) { + int64_t channelId = state.nextChannelId++; + int32_t sourceCpu = static_cast(sourceClass.cpus.front()); + int32_t targetCpu = static_cast(targetClass.cpus.front()); + appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc); + Value received = appendReceive(state, targetClass, payload.getType(), ArrayRef(channelId), + ArrayRef(sourceCpu), ArrayRef(targetCpu), loc); + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = received; + return success(); + } + + if (!sourceClass.isBatch && targetClass.isBatch) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(targetClass.cpus.size()); + sourceCoreIds.reserve(targetClass.cpus.size()); + targetCoreIds.reserve(targetClass.cpus.size()); + + for (CpuId targetCpu : targetClass.cpus) { + int64_t channelId = state.nextChannelId++; + channelIds.push_back(channelId); + sourceCoreIds.push_back(static_cast(sourceClass.cpus.front())); + targetCoreIds.push_back(static_cast(targetCpu)); + appendScalarSend(state, + sourceClass, + payload, + channelId, + static_cast(sourceClass.cpus.front()), + static_cast(targetCpu), + loc); + } + + Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = received; + return success(); + } + + if (sourceClass.isBatch && !targetClass.isBatch) { + std::optional packedKey = getContiguousProducerKeyForKeys(keys); + if (!packedKey) + return sourceClass.op->emitError( + "cannot materialize batch-to-scalar communication as concat because source lanes are not contiguous in send order"); + + FailureOr packedType = getPackedBatchTensorType(payload.getType(), keys.size()); + if (failed(packedType)) + return sourceClass.op->emitError( + "cannot materialize batch-to-scalar communication as concat for non-static ranked tensor payload"); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(sourceClass.cpus.size()); + sourceCoreIds.reserve(sourceClass.cpus.size()); + targetCoreIds.reserve(sourceClass.cpus.size()); + + for (CpuId sourceCpu : sourceClass.cpus) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(static_cast(sourceCpu)); + targetCoreIds.push_back(static_cast(targetClass.cpus.front())); + } + + appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = appendReceive(state, targetClass, *packedType, channelIds, sourceCoreIds, targetCoreIds, loc); + state.availableValues[*packedKey][targetClass.id] = received; + return success(); + } + + if (sourceClass.isBatch && targetClass.isBatch) { + if (sourceClass.cpus.size() != targetClass.cpus.size()) + return sourceClass.op->emitError("cannot materialize batch communication between equivalence classes of different sizes"); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(sourceClass.cpus.size()); + sourceCoreIds.reserve(sourceClass.cpus.size()); + targetCoreIds.reserve(targetClass.cpus.size()); + + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(static_cast(sourceCpu)); + targetCoreIds.push_back(static_cast(targetClass.cpus[lane])); + } + + appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = received; + return success(); + } + + return sourceClass.op->emitError("unhandled materialized communication pattern"); +} + +LogicalResult emitHostCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) + return success(); + + if (!sourceClass.hostOutputs.empty()) + return setHostOutputValue(state, sourceClass, originalOutput, payload); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(sourceClass.cpus.size()); + sourceCoreIds.reserve(sourceClass.cpus.size()); + targetCoreIds.reserve(sourceClass.cpus.size()); + for (CpuId sourceCpu : sourceClass.cpus) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(static_cast(sourceCpu)); + targetCoreIds.push_back(0); + } + + appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + state.hostReplacements[originalOutput] = received; + return success(); +} + +LogicalResult emitOutputFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (keys.empty()) + return success(); + + if (!sourceClass.isBatch) { + for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) + if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) + return failure(); + if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc))) + return failure(); + state.availableValues[keys.front()][sourceClass.id] = payload; + return success(); + } + + if (!haveSameDestinationClasses(state, keys)) + return sourceClass.op->emitError( + "cannot materialize batched output whose lanes have different destination equivalence classes"); + + for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) + if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); - auto& queue = queueIt->second; - size_t& queueIndex = receiveQueueIndices[pairKey]; - while (queueIndex < queue.size()) { - const RemoteReceiveEntry& entry = queue[queueIndex++]; - auto consumerTaskIt = taskByComputeInstance.find(entry.consumer); - if (consumerTaskIt == taskByComputeInstance.end()) - return failure(); - SmallVector consumerInputs = getComputeInstanceInputs(consumerTaskIt->second.computeInstance); - if (consumerInputs.size() <= entry.inputIndex) - return failure(); - Type inputType = consumerInputs[entry.inputIndex].getType(); - Value channelId = createIndexConstant(entry.consumer.op, entry.channelInfo.channelId, constantFolder); - Value sourceCoreId = createIndexConstant(entry.consumer.op, entry.channelInfo.sourceCoreId, constantFolder); - Value targetCoreId = createIndexConstant(entry.consumer.op, entry.channelInfo.targetCoreId, constantFolder); - auto receive = - spatial::SpatChannelReceiveOp::create(rewriter, loc, inputType, channelId, sourceCoreId, targetCoreId); + if (failed(emitHostCommunication(state, sourceClass, keys, payload, originalOutput, loc))) + return failure(); - auto& receivedInputs = preReceivedInputsByTask[entry.consumer]; - if (receivedInputs.size() <= entry.inputIndex) - receivedInputs.resize(entry.inputIndex + 1); - receivedInputs[entry.inputIndex] = receive.getResult(); + for (ProducerKey key : keys) + state.availableValues[key][sourceClass.id] = payload; + return success(); +} - if (entry.consumer == requestedConsumer && entry.inputIndex == requestedInputIndex) - return receive.getResult(); +FailureOr materializeWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType, + Location loc) { + auto batch = dyn_cast_or_null(key.instance.op); + auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || resultTensorType.getRank() == 0) + return failure(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + SmallVector fragments; + uint32_t lane = 0; + + while (lane < batchLaneCount) { + bool foundFragment = false; + + for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) { + ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex); + std::optional fragment = lookupAvailableValue(state, candidate, targetClass.id); + if (!fragment) + continue; + + fragments.push_back(*fragment); + lane += laneCount; + foundFragment = true; + break; } + + if (!foundFragment) + return failure(); + } + + if (fragments.empty()) + return failure(); + + Value result = fragments.front(); + if (fragments.size() != 1) + result = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult(); + + if (result.getType() != resultType) + result = tensor::CastOp::create(state.rewriter, loc, resultType, result).getResult(); + + state.availableValues[key][targetClass.id] = result; + return result; +} + +FailureOr resolveInputValue(MaterializerState& state, + MaterializedClass& targetClass, + Value input, + const ComputeInstance& consumerInstance) { + if (isConstantLike(input)) + return input; + + if (std::optional producer = getProducerKey(input, &consumerInstance)) { + if (std::optional value = lookupAvailableValue(state, *producer, targetClass.id)) + return *value; + + if (isWholeBatchProducerKey(*producer)) + return materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); + return failure(); } - LogicalResult cloneTaskBodies() { - DenseMap> receiveQueueIndicesByCpu; - DenseMap>> preReceivedInputsByCpu; + return appendInput(state, targetClass, input); +} - auto lookupPreReceivedInput = [&](DenseMap>& preReceivedInputsByTask, - ComputeInstance consumer, - size_t inputIndex) -> std::optional { - auto inputsIt = preReceivedInputsByTask.find(consumer); - if (inputsIt == preReceivedInputsByTask.end() || inputsIt->second.size() <= inputIndex) - return std::nullopt; - Value value = inputsIt->second[inputIndex]; - if (!value) - return std::nullopt; - return value; - }; +void mapWeights(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper) { + Operation* op = instance.op; + if (auto compute = dyn_cast(op)) { + for (auto [index, weight] : llvm::enumerate(compute.getWeights())) + mapper.map(compute.getWeightArgument(index), appendWeight(state, targetClass, weight)); + return; + } - for (ProgramKey leader : orderedPrograms) { - const auto& batch = batchedCpus[leader]; - bool isBatch = batch.size() > 1; - CpuProgram& program = cpuPrograms[leader]; + auto batch = cast(op); + for (auto [index, weight] : llvm::enumerate(batch.getWeights())) + mapper.map(batch.getWeightArgument(index), appendWeight(state, targetClass, weight)); +} - IRRewriter rewriter(func.getContext()); - rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front()); - - auto& receiveQueueIndices = receiveQueueIndicesByCpu[leader]; - auto& preReceivedInputsByTask = preReceivedInputsByCpu[leader]; - - ArrayRef leaderTasks = tasksByCpu[leader]; - - for (size_t taskIndex = 0; taskIndex < leaderTasks.size(); ++taskIndex) { - const ScheduledTask& task = leaderTasks[taskIndex]; - SmallVector taskInputs = getComputeInstanceInputs(task.computeInstance); - auto taskWeights = getComputeInstanceWeights(task.computeInstance); - Block& templateBlock = getComputeInstanceTemplateBlock(task.computeInstance); - - SmallVector resolvedInputs; - resolvedInputs.reserve(taskInputs.size()); - - auto remoteInputsIt = remoteInputsByTask.find(task.computeInstance); - - for (auto [inputIndex, input] : llvm::enumerate(taskInputs)) { - auto producerRef = getProducerValueRef(input, &task.computeInstance); - if (producerRef) { - auto producerIt = taskByComputeInstance.find(producerRef->instance); - if (producerIt != taskByComputeInstance.end()) { - if (producerIt->second.cpu == leader) { - auto producedIt = producedValuesByTask.find(producerRef->instance); - if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= producerRef->resultIndex) - return task.computeInstance.op->emitOpError("missing local producer value"); - resolvedInputs.push_back(producedIt->second[producerRef->resultIndex]); - continue; - } - - if (isBatch) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - for (size_t cpu : batch) { - const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex]; - const ChannelInfo& info = *remoteInputsByTask[laneTask.computeInstance][inputIndex]; - channelIds.push_back(info.channelId); - sourceCoreIds.push_back(info.sourceCoreId); - targetCoreIds.push_back(info.targetCoreId); - } - - SmallVector cIds = createIndexConstants(program.op, channelIds, constantFolder); - SmallVector sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder); - SmallVector tIds = createIndexConstants(program.op, targetCoreIds, constantFolder); - - auto recv = - spatial::SpatChannelReceiveBatchOp::create(rewriter, loc, input.getType(), cIds, sIds, tIds); - resolvedInputs.push_back(recv.getOutput()); - } - else { - const ChannelInfo& channelInfo = *remoteInputsIt->second[inputIndex]; - uint64_t pairKey = getRemoteSendPairKey(channelInfo); - - if (pairsNeedingReceiveReorder.contains(pairKey)) { - if (std::optional preReceived = - lookupPreReceivedInput(preReceivedInputsByTask, task.computeInstance, inputIndex)) { - resolvedInputs.push_back(*preReceived); - continue; - } - FailureOr received = receiveThroughInput(rewriter, - leader, - receiveQueueIndices, - preReceivedInputsByTask, - channelInfo, - task.computeInstance, - inputIndex); - if (failed(received)) - return task.computeInstance.op->emitOpError("failed to materialize reordered remote receive"); - resolvedInputs.push_back(*received); - continue; - } - - Value cId = createIndexConstant(program.op, channelInfo.channelId, constantFolder); - Value sId = createIndexConstant(program.op, channelInfo.sourceCoreId, constantFolder); - Value tId = createIndexConstant(program.op, channelInfo.targetCoreId, constantFolder); - auto receive = spatial::SpatChannelReceiveOp::create(rewriter, loc, input.getType(), cId, sId, tId); - resolvedInputs.push_back(receive.getResult()); - } - continue; - } - } - resolvedInputs.push_back(program.externalInputMap.at(input)); - } - - SmallVector taskYieldValues; - rewriter.setInsertionPointToEnd(&program.op->getRegion(0).front()); - - if (isa(task.computeInstance.op)) { - IRMapping mapper; - auto compute = cast(task.computeInstance.op); - for (auto [weightIndex, weight] : llvm::enumerate(taskWeights)) { - Value destArg = isBatch - ? cast(program.op).getWeightArgument(program.weightToIndex.at(weight)) - : cast(program.op).getWeightArgument(program.weightToIndex.at(weight)); - mapper.map(compute.getWeightArgument(weightIndex), destArg); - } - for (auto [inputIndex, input] : llvm::enumerate(resolvedInputs)) - mapper.map(compute.getInputArgument(inputIndex), input); - - for (Operation& op : templateBlock) { - if (auto yield = dyn_cast(&op)) { - for (Value yieldOperand : yield.getOperands()) - taskYieldValues.push_back(mapper.lookup(yieldOperand)); - continue; - } - rewriter.clone(op, mapper); - } - } - else { - // Include your existing isolated logic for preserving resultless spat.compute_batch here if needed - } - - producedValuesByTask[task.computeInstance] = taskYieldValues; - - if (auto sendsIt = remoteSendsByTask.find(task.computeInstance); sendsIt != remoteSendsByTask.end()) { - for (size_t resultIndex = 0; resultIndex < sendsIt->second.size();) { - const SmallVector& sendInfos = sendsIt->second[resultIndex]; - if (sendInfos.empty()) { - ++resultIndex; - continue; - } - - if (isBatch) { - size_t numSends = sendInfos.size(); - for (size_t s = 0; s < numSends; ++s) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - - for (size_t cpu : batch) { - const ScheduledTask& laneTask = tasksByCpu[cpu][taskIndex]; - const RemoteSendInfo& send = remoteSendsByTask[laneTask.computeInstance][resultIndex][s]; - channelIds.push_back(send.channelInfo.channelId); - sourceCoreIds.push_back(send.channelInfo.sourceCoreId); - targetCoreIds.push_back(send.channelInfo.targetCoreId); - } - - SmallVector cIds = createIndexConstants(program.op, channelIds, constantFolder); - SmallVector sIds = createIndexConstants(program.op, sourceCoreIds, constantFolder); - SmallVector tIds = createIndexConstants(program.op, targetCoreIds, constantFolder); - - spatial::SpatChannelSendBatchOp::create(rewriter, loc, cIds, sIds, tIds, taskYieldValues[resultIndex]); - } - ++resultIndex; - } - else { - size_t nextResultIndex = resultIndex + 1; - if (tryEmitCompactSendLoops( - program.op, rewriter, sendsIt->second, taskYieldValues, resultIndex, nextResultIndex)) { - resultIndex = nextResultIndex; - continue; - } - - Value producedValue = taskYieldValues[resultIndex]; - for (const RemoteSendInfo& sendInfo : sendInfos) { - Value cId = createIndexConstant(program.op, sendInfo.channelInfo.channelId, constantFolder); - Value sId = createIndexConstant(program.op, sendInfo.channelInfo.sourceCoreId, constantFolder); - Value tId = createIndexConstant(program.op, sendInfo.channelInfo.targetCoreId, constantFolder); - spatial::SpatChannelSendOp::create(rewriter, loc, cId, sId, tId, producedValue); - } - ++resultIndex; - } - } - } - } - - SmallVector yieldValues; - yieldValues.reserve(cpuExternalOutputs[leader].size()); - for (ProducerValueRef outputRef : cpuExternalOutputs[leader]) { - auto producedIt = producedValuesByTask.find(outputRef.instance); - if (producedIt == producedValuesByTask.end() || producedIt->second.size() <= outputRef.resultIndex) - return func.emitError("missing yielded external value during materialization"); - yieldValues.push_back(producedIt->second[outputRef.resultIndex]); - } - - if (isBatch) { - auto batchOp = cast(program.op); - auto inParallel = spatial::SpatInParallelOp::create(rewriter, loc); - Block* parallelBlock = rewriter.createBlock(&inParallel.getRegion()); - rewriter.setInsertionPointToEnd(parallelBlock); - - for (auto [resultIndex, yieldedVal] : llvm::enumerate(yieldValues)) { - auto destArg = batchOp.getOutputArgument(resultIndex); - auto destType = cast(destArg.getType()); - - SmallVector offsets; - offsets.push_back(batchOp.getLaneArgument()); - SmallVector sizes; - sizes.push_back(rewriter.getIndexAttr(1)); - SmallVector strides; - strides.push_back(rewriter.getIndexAttr(1)); - - for (int64_t dim : destType.getShape().drop_front()) { - offsets.push_back(rewriter.getIndexAttr(0)); - sizes.push_back(rewriter.getIndexAttr(dim)); - strides.push_back(rewriter.getIndexAttr(1)); - } - tensor::ParallelInsertSliceOp::create(rewriter, loc, yieldedVal, destArg, offsets, sizes, strides); - } - } - else { - spatial::SpatYieldOp::create(rewriter, loc, ValueRange(yieldValues)); - } +LogicalResult mapInputs(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper) { + Operation* op = instance.op; + if (auto compute = dyn_cast(op)) { + for (auto [index, input] : llvm::enumerate(compute.getInputs())) { + FailureOr mapped = resolveInputValue(state, targetClass, input, instance); + if (failed(mapped)) + return compute.emitOpError("failed to resolve materialized compute input"); + mapper.map(compute.getInputArgument(index), *mapped); } - return success(); } - void replaceExternalUses() { - for (auto [oldValue, newValue] : oldToNewExternalValueMap) { - for (auto& use : llvm::make_early_inc_range(oldValue.getUses())) - if (!oldComputeOps.contains(use.getOwner())) - use.assign(newValue); - } + auto batch = cast(op); + for (auto [index, input] : llvm::enumerate(batch.getInputs())) { + FailureOr mapped = resolveInputValue(state, targetClass, input, instance); + if (failed(mapped)) + return batch.emitOpError("failed to resolve materialized compute_batch input"); + mapper.map(batch.getInputArgument(index), *mapped); + } + return success(); +} - IRRewriter rewriter(func.getContext()); - for (auto& [op, resultLaneValues] : resultfulBatchLaneResults) { - auto batch = cast(op); - for (auto [resultIndex, laneValues] : llvm::enumerate(resultLaneValues)) { - bool hasNonScheduledUse = false; - for (Operation* user : batch.getResult(resultIndex).getUsers()) { - if (!oldComputeOps.contains(user)) { - hasNonScheduledUse = true; - break; - } - } - if (!hasNonScheduledUse) - continue; +SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMapping& mapper) { + SmallVector outputs(batch.getNumResults(), Value {}); + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return outputs; - if (laneValues.size() != static_cast(batch.getLaneCount()) - || llvm::any_of(laneValues, [](Value value) { return !value; })) { - batch.emitOpError("missing materialized lane result while rebuilding resultful compute_batch result"); - continue; - } + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; - rewriter.setInsertionPoint(returnOp); - Value packedResult = - tensor::ConcatOp::create(rewriter, batch.getLoc(), /*dim=*/0, ValueRange(laneValues)).getResult(); - batch.getResult(resultIndex).replaceAllUsesWith(packedResult); - } - } + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + unsigned firstOutputArg = batch.getOutputArgument(0).getArgNumber(); + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg; + if (resultIndex >= outputs.size()) + continue; + outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource()); } - LogicalResult eraseOldScheduledOps() { - SmallVector orderedOpsToErase; - for (Operation& op : func.getBody().front()) - if (oldComputeOps.contains(&op)) - orderedOpsToErase.push_back(&op); + return outputs; +} - for (Operation* op : llvm::reverse(orderedOpsToErase)) { - SmallVector remainingUsers; - for (Value result : op->getResults()) - for (Operation* user : result.getUsers()) - remainingUsers.push_back(user); - if (!remainingUsers.empty()) { - InFlightDiagnostic diagnostic = op->emitOpError("still has uses during per-cpu merge cleanup") - << "; erase-set=" << (oldComputeOps.contains(op) ? "yes" : "no"); - for (Operation* user : remainingUsers) { - diagnostic.attachNote(user->getLoc()) - << "remaining user " << user->getName() << "; erase-set=" << (oldComputeOps.contains(user) ? "yes" : "no"); - } +FailureOr> cloneInstanceBody(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef peers) { + assert(!peers.empty() && "expected at least one peer instance"); + const ComputeInstance& instance = peers.front(); + Operation* sourceOp = instance.op; + Location loc = sourceOp->getLoc(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + IRMapping mapper; + if (auto batch = dyn_cast(sourceOp)) { + for (const ComputeInstance& peer : peers) { + if (peer.op != sourceOp) { + sourceOp->emitError("equivalence class slot contains different source compute_batch operations"); + return failure(); + } + if (peer.laneCount != 1) { + sourceOp->emitError("schedule materialization currently expects one original batch lane per CPU"); return failure(); } - op->erase(); } - - return success(); + mapper.map(batch.getLaneArgument(), createOriginalLaneValue(state, targetClass, peers, loc)); } - func::FuncOp func; - const MergeScheduleResult* schedule = nullptr; - int64_t* nextChannelId = nullptr; - Location loc; - func::ReturnOp returnOp; - OperationFolder constantFolder; + mapWeights(state, targetClass, instance, mapper); + if (failed(mapInputs(state, targetClass, instance, mapper))) + return failure(); - SmallVector scheduledTasks; - DenseSet oldComputeOps; - DenseMap taskByComputeInstance; - DenseMap> tasksByCpu; - DenseMap> tasksByProgram; - SmallVector orderedCpus; - DenseSet seenCpus; - DenseSet seenPrograms; - DenseMap>> remoteSendsByTask; - DenseMap>> remoteInputsByTask; - DenseMap, 4>> remoteTensorInputsByTask; - DenseMap> cpuExternalInputs; - DenseMap> cpuWeights; - DenseMap> cpuExternalOutputs; - DenseMap> seenExternalInputsByProgram; - DenseMap> seenWeightsByProgram; - DenseSet pairsNeedingReceiveReorder; - DenseMap>> receiveQueuesByCpu; - DenseMap cpuPrograms; - DenseMap oldToNewExternalValueMap; - DenseMap> producedValuesByTask; - DenseMap>> resultfulBatchLaneResults; -}; + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); + for (Operation& op : sourceBlock.without_terminator()) { + Operation* cloned = state.rewriter.clone(op, mapper); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(oldResult, newResult); + } + + if (auto compute = dyn_cast(sourceOp)) { + auto yield = dyn_cast_or_null(sourceBlock.getTerminator()); + if (!yield) { + compute.emitOpError("expected spat.yield terminator while materializing compute"); + return failure(); + } + + SmallVector outputs; + outputs.reserve(yield.getNumOperands()); + for (Value yielded : yield.getOutputs()) + outputs.push_back(mapper.lookupOrDefault(yielded)); + return outputs; + } + + auto batch = cast(sourceOp); + if (batch.getNumResults() == 0) + return SmallVector {}; + + SmallVector outputs = collectMappedBatchOutputs(batch, mapper); + for (Value output : outputs) + if (!output) { + batch.emitOpError("failed to recover yielded per-lane value for compute_batch result"); + return failure(); + } + return outputs; +} + +LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) { + auto cpuIt = state.schedule.computeToCpuMap.find(instance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); + auto slotIt = state.schedule.computeToCpuSlotMap.find(instance); + if (slotIt == state.schedule.computeToCpuSlotMap.end()) + return instance.op->emitError("schedule materialization expected a CPU slot for every compute instance"); + + ClassId classId = state.cpuToClass.lookup(cpuIt->second); + ClassSlotKey classSlot {classId, slotIt->second}; + if (!state.materializedSlots.insert(classSlot).second) + return success(); + + MaterializedClass& targetClass = state.classes[classId]; + FailureOr> peers = getPeerInstances(state, targetClass, slotIt->second); + if (failed(peers)) + return instance.op->emitError("failed to collect peer compute instances for equivalence class slot"); + + FailureOr> materializedOutputs = cloneInstanceBody(state, targetClass, *peers); + if (failed(materializedOutputs)) + return failure(); + + SmallVector originalOutputs = getComputeInstanceOutputValues(instance); + if (materializedOutputs->size() != originalOutputs.size()) + return instance.op->emitError("materialized output count does not match original compute instance output count"); + + for (auto [resultIndex, zipped] : llvm::enumerate(llvm::zip(*materializedOutputs, originalOutputs))) { + Value materializedOutput = std::get<0>(zipped); + Value originalOutput = std::get<1>(zipped); + SmallVector keys = getOutputKeysForPeers(*peers, resultIndex); + if (failed(emitOutputFanout(state, targetClass, keys, materializedOutput, originalOutput, instance.op->getLoc()))) + return failure(); + } + + return success(); +} + +void replaceHostUses(MaterializerState& state) { + for (const auto& [oldValue, replacement] : state.hostReplacements) + replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); +} + +LogicalResult eraseOldComputeOps(MaterializerState& state) { + DenseSet seen; + for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { + if (!seen.insert(instance.op).second) + continue; + instance.op->dropAllUses(); + instance.op->erase(); + } + return success(); +} } // namespace -LogicalResult -MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { - return MergeScheduleMaterializerImpl(func).run(schedule, nextChannelId); +LogicalResult MergeScheduleMaterializer::run(func::FuncOp func, + const MergeScheduleResult& schedule, + int64_t& nextChannelId) { + if (schedule.dominanceOrderCompute.empty()) + return success(); + + MaterializerState state(func, schedule, nextChannelId); + if (failed(buildEquivalenceClasses(state))) + return failure(); + if (state.classes.empty()) + return success(); + + if (failed(collectHostOutputs(state))) + return failure(); + createEmptyMaterializedOps(state); + if (failed(collectProducerDestinations(state))) + return failure(); + + for (const ComputeInstance& instance : schedule.dominanceOrderCompute) + if (failed(materializeInstanceSlot(state, instance))) + return failure(); + + replaceHostUses(state); + if (failed(eraseOldComputeOps(state))) + return failure(); + + return success(); } } // namespace spatial diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp index c3bae1d..cefaa1a 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/PostMergeCompaction.cpp @@ -267,212 +267,6 @@ bool areEquivalentForRebatch(SpatCompute lhs, SpatCompute rhs) { return lhsIt == lhsBlock.end() && rhsIt == rhsBlock.end(); } -struct BatchYieldInfo { - Value yieldedValue; - tensor::ParallelInsertSliceOp insertSlice; -}; - -static bool isHostOnlyBatchResultUser(Operation* user) { - return isa(user); -} - -static FailureOr> collectBatchYieldInfo(SpatComputeBatch batchOp) { - Block& block = batchOp.getBody().front(); - auto inParallel = dyn_cast(block.getTerminator()); - if (!inParallel) - return failure(); - - DenseMap batchYieldByOutputArg; - for (Operation& op : inParallel.getRegion().front()) { - auto insertSlice = dyn_cast(&op); - if (!insertSlice) - return failure(); - auto outputArg = dyn_cast(insertSlice.getDest()); - if (!outputArg || outputArg.getOwner() != &block) - return failure(); - batchYieldByOutputArg[outputArg] = {insertSlice.getSource(), insertSlice}; - } - return batchYieldByOutputArg; -} - -static FailureOr cloneBatchAsResultless(SpatComputeBatch batchOp, IRRewriter& rewriter) { - auto coreIdsAttr = batchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); - if (!coreIdsAttr) - return failure(); - - Block& oldBlock = batchOp.getBody().front(); - rewriter.setInsertionPoint(batchOp); - auto newBatch = SpatComputeBatch::create(rewriter, - batchOp.getLoc(), - TypeRange {}, - rewriter.getI32IntegerAttr(batchOp.getLaneCount()), - batchOp.getWeights(), - batchOp.getInputs()); - newBatch.getProperties().setOperandSegmentSizes( - {static_cast(batchOp.getWeights().size()), static_cast(batchOp.getInputs().size())}); - newBatch->setAttr(onnx_mlir::kCoreIdsAttrName, coreIdsAttr); - - SmallVector blockArgTypes; - SmallVector blockArgLocs; - blockArgTypes.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size()); - blockArgLocs.reserve(1 + batchOp.getWeights().size() + batchOp.getInputs().size()); - blockArgTypes.push_back(batchOp.getLaneArgument().getType()); - blockArgLocs.push_back(batchOp.getLaneArgument().getLoc()); - for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) { - blockArgTypes.push_back(batchOp.getWeightArgument(weightIndex).getType()); - blockArgLocs.push_back(batchOp.getWeightArgument(weightIndex).getLoc()); - } - for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) { - blockArgTypes.push_back(batchOp.getInputArgument(inputIndex).getType()); - blockArgLocs.push_back(batchOp.getInputArgument(inputIndex).getLoc()); - } - - Block* newBlock = - rewriter.createBlock(&newBatch.getBody(), newBatch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); - rewriter.setInsertionPointToStart(newBlock); - - IRMapping mapper; - mapper.map(batchOp.getLaneArgument(), newBatch.getLaneArgument()); - for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) - mapper.map(batchOp.getWeightArgument(weightIndex), newBatch.getWeightArgument(weightIndex)); - for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) - mapper.map(batchOp.getInputArgument(inputIndex), newBatch.getInputArgument(inputIndex)); - - for (Operation& op : oldBlock.without_terminator()) { - Operation* cloned = rewriter.clone(op, mapper); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(oldResult, newResult); - } - - return newBatch; -} - -static LogicalResult materializeBatchResultCommunication(func::FuncOp funcOp, int64_t& nextChannelId) { - IRRewriter rewriter(funcOp.getContext()); - OperationFolder constantFolder(funcOp.getContext()); - SmallVector batches(funcOp.getOps()); - - for (auto batchOp : batches) { - if (batchOp.getNumResults() == 0) - continue; - - auto coreIdsAttr = batchOp->getAttrOfType(onnx_mlir::kCoreIdsAttrName); - if (!coreIdsAttr) - return batchOp.emitOpError("missing coreIds while materializing batch result communication"); - - FailureOr> batchYieldInfo = collectBatchYieldInfo(batchOp); - if (failed(batchYieldInfo)) - return batchOp.emitOpError("failed to collect per-result yielded values from compute_batch body"); - - FailureOr newBatch = cloneBatchAsResultless(batchOp, rewriter); - if (failed(newBatch)) - return batchOp.emitOpError("failed to clone resultful compute_batch as resultless"); - - Block& oldBlock = batchOp.getBody().front(); - Block& newBlock = newBatch->getBody().front(); - IRMapping mapper; - mapper.map(batchOp.getLaneArgument(), newBatch->getLaneArgument()); - for (unsigned weightIndex = 0; weightIndex < batchOp.getWeights().size(); ++weightIndex) - mapper.map(batchOp.getWeightArgument(weightIndex), newBatch->getWeightArgument(weightIndex)); - for (unsigned inputIndex = 0; inputIndex < batchOp.getInputs().size(); ++inputIndex) - mapper.map(batchOp.getInputArgument(inputIndex), newBatch->getInputArgument(inputIndex)); - auto oldIt = oldBlock.begin(); - auto newIt = newBlock.begin(); - for (; oldIt != oldBlock.end() && newIt != newBlock.end(); ++oldIt, ++newIt) - for (auto [oldResult, newResult] : llvm::zip(oldIt->getResults(), newIt->getResults())) - mapper.map(oldResult, newResult); - - SmallVector sourceCoreIds(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); - rewriter.setInsertionPointToEnd(&newBlock); - - for (unsigned resultIndex = 0; resultIndex < batchOp.getNumResults(); ++resultIndex) { - BlockArgument outputArg = batchOp.getOutputArgument(resultIndex); - auto yieldInfoIt = batchYieldInfo->find(outputArg); - if (yieldInfoIt == batchYieldInfo->end()) - return batchOp.emitOpError( - "missing yielded value for compute_batch result during communication materialization"); - Value mappedYieldedValue = mapper.lookup(yieldInfoIt->second.yieldedValue); - - DenseMap> computeUsesByTargetCore; - SmallVector hostUses; - for (OpOperand& use : batchOp.getResult(resultIndex).getUses()) { - if (auto computeOp = dyn_cast(use.getOwner())) { - auto coreIdAttr = computeOp->getAttrOfType(onnx_mlir::kCoreIdAttrName); - if (!coreIdAttr) - return batchOp.emitOpError("compute user of compute_batch result is missing coreId"); - computeUsesByTargetCore[static_cast(coreIdAttr.getInt())].push_back(&use); - continue; - } - if (isHostOnlyBatchResultUser(use.getOwner())) { - hostUses.push_back(&use); - continue; - } - return batchOp.emitOpError("unsupported user of compute_batch result during communication materialization") - << ": " << use.getOwner()->getName(); - } - - auto createReceiveForUses = [&](ArrayRef uses, ArrayRef targetCoreIds) -> LogicalResult { - if (uses.empty()) - return success(); - - SmallVector channelIds; - channelIds.reserve(sourceCoreIds.size()); - for ([[maybe_unused]] int32_t sourceCoreId : sourceCoreIds) - channelIds.push_back(nextChannelId++); - SmallVector sendChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder); - SmallVector sendSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder); - SmallVector sendTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder); - - spatial::SpatChannelSendBatchOp::create(rewriter, - batchOp.getLoc(), - sendChannelIdValues, - sendSourceCoreIdValues, - sendTargetCoreIdValues, - mappedYieldedValue); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(newBatch->getOperation()); - SmallVector receiveChannelIdValues = createIndexConstants(batchOp, channelIds, constantFolder); - SmallVector receiveSourceCoreIdValues = createIndexConstants(batchOp, sourceCoreIds, constantFolder); - SmallVector receiveTargetCoreIdValues = createIndexConstants(batchOp, targetCoreIds, constantFolder); - auto received = spatial::SpatChannelReceiveTensorOp::create(rewriter, - batchOp.getLoc(), - batchOp.getResult(resultIndex).getType(), - receiveChannelIdValues, - receiveSourceCoreIdValues, - receiveTargetCoreIdValues); - for (OpOperand* use : uses) - use->set(received.getOutput()); - rewriter.setInsertionPointToEnd(&newBlock); - return success(); - }; - - for (auto& [targetCoreId, uses] : computeUsesByTargetCore) { - SmallVector targetCoreIds(static_cast(batchOp.getLaneCount()), targetCoreId); - if (failed(createReceiveForUses(uses, targetCoreIds))) - return failure(); - } - - if (!hostUses.empty()) { - SmallVector hostTargetCoreIds(static_cast(batchOp.getLaneCount()), 0); - if (failed(createReceiveForUses(hostUses, hostTargetCoreIds))) - return failure(); - } - } - - rewriter.setInsertionPointToEnd(&newBlock); - spatial::SpatYieldOp::create(rewriter, batchOp.getLoc(), ValueRange {}); - rewriter.eraseOp(batchOp); - } - - return success(); -} - void rebatchEquivalentComputes(func::FuncOp funcOp) { IRRewriter rewriter(funcOp.getContext()); OperationFolder constantFolder(funcOp.getContext()); @@ -731,11 +525,6 @@ LogicalResult runPostMergeCompactionPipeline(func::FuncOp funcOp, int64_t& nextC ScopedMergePhaseTimer timer("cleanup-dead-packing-ops"); cleanupDeadPackingOps(funcOp); } - { - ScopedMergePhaseTimer timer("materialize-batch-result-communication"); - if (failed(materializeBatchResultCommunication(funcOp, nextChannelId))) - return failure(); - } return success(); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp index b62f1e2..fe0d639 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/Scheduling/PeftScheduler.cpp @@ -1,7 +1,6 @@ #include "mlir/IR/Threading.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" @@ -20,7 +19,6 @@ struct ScheduledTask { size_t processor = std::numeric_limits::max(); Time startTime = 0; Time endTime = 0; - size_t slot = 0; }; std::vector> buildReverseLevels(const ComputeGraph& graph) { @@ -244,7 +242,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu llvm::report_fatal_error(llvm::StringRef(message)); } - schedules[task] = {bestProcessor, bestEst, bestEft, 0}; + schedules[task] = {bestProcessor, bestEst, bestEft}; scheduled[task] = true; ++scheduledCount; processorCrossbars[bestProcessor] = addOrMax(processorCrossbars[bestProcessor], graph.nodes[task].crossbarUsage); @@ -278,7 +276,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu // 5. Check if equal schedule in two level llvm::DenseMap> equivalentClass; for (size_t currentProcessor = 0; currentProcessor < processorCount - 1; ++currentProcessor) { - for (size_t controlProcessor = currentProcessor + 1; controlProcessor < processorCount; ++controlProcessor) { + for (size_t controlProcessor = currentProcessor; controlProcessor < processorCount; ++controlProcessor) { if (tasksByProcessor[currentProcessor].size() != tasksByProcessor[controlProcessor].size()) continue; auto& currentTasks = tasksByProcessor[currentProcessor]; @@ -288,7 +286,8 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu for (auto [currentTask, controlTask] : llvm::zip(currentTasks, controlTasks)) { const ComputeInstance currentComputeInstance = graph.nodes[currentTask].instance; const ComputeInstance controlComputeInstance = graph.nodes[controlTask].instance; - if (currentComputeInstance.op != controlComputeInstance.op) { + if (currentComputeInstance.op != controlComputeInstance.op + || currentComputeInstance.laneCount != controlComputeInstance.laneCount) { equalSchedule = false; break; } @@ -300,11 +299,11 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu } } } -{ + /*{ llvm::dbgs() << "--- Scheduling Equivalence Classes ---\n"; std::vector visited(processorCount, false); size_t uniqueClassCount = 0; - + for (size_t i = 0; i < processorCount; ++i) { if (visited[i]) continue; @@ -312,7 +311,7 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu // We found a new unique schedule (equivalence class) ++uniqueClassCount; visited[i] = true; - + llvm::dbgs() << "Class " << uniqueClassCount << ": CPUs { " << i; // Find and mark all identical companions @@ -327,10 +326,10 @@ MergeScheduleResult runPeftScheduler(const ComputeGraph& graph, const PeftSchedu } llvm::dbgs() << " }\n"; } - + llvm::dbgs() << "Total unique CPU nodes to emit: " << uniqueClassCount << "\n"; llvm::dbgs() << "--------------------------------------\n"; - } + }*/ // 6. Populate Final Result MergeScheduleResult result;