From c734f1b37e7dbb73281534e14f544ac5f9b0a9e6 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Sun, 24 May 2026 10:10:24 +0200 Subject: [PATCH] better MaterializeMergeSchedule.cpp that emits much more compact IR add support for other constant-time arith ops in codegen --- src/PIM/Common/IR/AddressAnalysis.cpp | 64 + src/PIM/Common/IR/CoreBlockUtils.cpp | 35 +- src/PIM/Compiler/PimCodeGen.cpp | 12 + src/PIM/Compiler/PimCodeGen.hpp | 1 + .../MaterializeMergeSchedule.cpp | 2379 ++++++++++++++--- .../MaterializeHostConstantsPass.cpp | 4 +- 6 files changed, 2067 insertions(+), 428 deletions(-) diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index a8daf58..3019749 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -4,6 +4,8 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" +#include + #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" @@ -99,6 +101,33 @@ static llvm::FailureOr resolveConstantGlobalLoad(mlir::memref::LoadOp l return denseAttr.getValues()[linearIndex].getSExtValue(); } +static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) { + switch (predicate) { + case mlir::arith::CmpIPredicate::eq: + return lhs == rhs; + case mlir::arith::CmpIPredicate::ne: + return lhs != rhs; + case mlir::arith::CmpIPredicate::slt: + return lhs < rhs; + case mlir::arith::CmpIPredicate::sle: + return lhs <= rhs; + case mlir::arith::CmpIPredicate::sgt: + return lhs > rhs; + case mlir::arith::CmpIPredicate::sge: + return lhs >= rhs; + case mlir::arith::CmpIPredicate::ult: + return static_cast(lhs) < static_cast(rhs); + case mlir::arith::CmpIPredicate::ule: + return static_cast(lhs) <= static_cast(rhs); + case mlir::arith::CmpIPredicate::ugt: + return static_cast(lhs) > static_cast(rhs); + case mlir::arith::CmpIPredicate::uge: + return static_cast(lhs) >= static_cast(rhs); + } + + llvm_unreachable("unknown cmpi predicate"); +} + llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { value = resolveAlias(value, knowledge); @@ -153,6 +182,16 @@ llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticVa return static_cast(static_cast(*lhs) / static_cast(*rhs)); } + if (auto divOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs) || *rhs == 0) + return mlir::failure(); + if (*lhs == std::numeric_limits::min() && *rhs == -1) + return mlir::failure(); + return *lhs / *rhs; + } + if (auto minOp = mlir::dyn_cast(definingOp)) { auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge); auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge); @@ -169,6 +208,31 @@ llvm::FailureOr resolveIndexValueImpl(mlir::Value value, const StaticVa return static_cast(static_cast(*lhs) % static_cast(*rhs)); } + if (auto remOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs) || *rhs == 0) + return mlir::failure(); + if (*lhs == std::numeric_limits::min() && *rhs == -1) + return 0; + return *lhs % *rhs; + } + + if (auto cmpOp = mlir::dyn_cast(definingOp)) { + auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge); + auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge); + if (failed(lhs) || failed(rhs)) + return mlir::failure(); + return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0; + } + + if (auto selectOp = mlir::dyn_cast(definingOp)) { + auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge); + if (failed(condition)) + return mlir::failure(); + return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge); + } + if (auto loadOp = mlir::dyn_cast(definingOp)) return resolveConstantGlobalLoad(loadOp, knowledge); diff --git a/src/PIM/Common/IR/CoreBlockUtils.cpp b/src/PIM/Common/IR/CoreBlockUtils.cpp index be78ba1..6c104ac 100644 --- a/src/PIM/Common/IR/CoreBlockUtils.cpp +++ b/src/PIM/Common/IR/CoreBlockUtils.cpp @@ -8,19 +8,28 @@ namespace onnx_mlir { bool isCoreStaticAddressOp(mlir::Operation* op) { - return mlir::isa(op); + if (mlir::isa(op)) + return true; + + if (auto selectOp = mlir::dyn_cast(op)) + return selectOp.getType().isIntOrIndex(); + + return false; } mlir::LogicalResult diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 2e26140..3897e52 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -405,6 +405,16 @@ void PimCodeGen::codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticVa loadOp.getSize()); } +void PimCodeGen::codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, + const StaticValueKnowledge& knowledge) const { + emitMemCopyOp("ld", + addressOf(loadOp.getDeviceTarget(), knowledge), + loadOp.getDeviceTargetOffset(), + addressOf(loadOp.getHostSource(), knowledge), + loadOp.getHostSourceOffset(), + loadOp.getSize()); +} + void PimCodeGen::codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const { auto hostTargetOffset = resolveIndexValue(storeOp.getHostTargetOffset(), knowledge); auto deviceSourceOffset = resolveIndexValue(storeOp.getDeviceSourceOffset(), knowledge); @@ -825,6 +835,8 @@ static int64_t codeGenCoreOps(Block& block, PimCodeGen& coreCodeGen) { walkPimCoreBlock(block, StaticValueKnowledge {}, [&](Operation& op, const StaticValueKnowledge& knowledge) { if (auto loadOp = dyn_cast(op)) coreCodeGen.codeGenLoadOp(loadOp, knowledge); + else if (auto loadBatchOp = dyn_cast(op)) + coreCodeGen.codeGenLoadBatchOp(loadBatchOp, knowledge); else if (auto storeOp = dyn_cast(op)) coreCodeGen.codeGenStoreOp(storeOp, knowledge); else if (auto lmvOp = dyn_cast(op)) diff --git a/src/PIM/Compiler/PimCodeGen.hpp b/src/PIM/Compiler/PimCodeGen.hpp index b792487..bbe45dc 100644 --- a/src/PIM/Compiler/PimCodeGen.hpp +++ b/src/PIM/Compiler/PimCodeGen.hpp @@ -143,6 +143,7 @@ public: uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; } void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; + void codeGenLoadBatchOp(pim::PimMemCopyHostToDevBatchOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 752eaf0..59b1ada 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -13,14 +13,11 @@ #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/FormatVariadic.h" #include #include #include #include -#include #include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" @@ -63,45 +60,8 @@ struct ProducerKeyInfo { static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } }; -struct CpuSlotKey { - CpuId cpu = 0; - SlotId slot = 0; - - bool operator==(const CpuSlotKey& other) const { return cpu == other.cpu && slot == other.slot; } -}; - -struct CpuSlotKeyInfo { - static CpuSlotKey getEmptyKey() { return {std::numeric_limits::max(), std::numeric_limits::max()}; } - - static CpuSlotKey getTombstoneKey() { - return {std::numeric_limits::max() - 1, std::numeric_limits::max()}; - } - - static unsigned getHashValue(const CpuSlotKey& key) { return llvm::hash_combine(key.cpu, key.slot); } - - static bool isEqual(const CpuSlotKey& lhs, const CpuSlotKey& rhs) { return lhs == rhs; } -}; - -struct ClassSlotKey { - ClassId classId = 0; - SlotId slot = 0; - - bool operator==(const ClassSlotKey& other) const { return classId == other.classId && slot == other.slot; } -}; - -struct ClassSlotKeyInfo { - static ClassSlotKey getEmptyKey() { - return {std::numeric_limits::max(), std::numeric_limits::max()}; - } - - static ClassSlotKey getTombstoneKey() { - return {std::numeric_limits::max() - 1, std::numeric_limits::max()}; - } - - static unsigned getHashValue(const ClassSlotKey& key) { return llvm::hash_combine(key.classId, key.slot); } - - static bool isEqual(const ClassSlotKey& lhs, const ClassSlotKey& rhs) { return lhs == rhs; } -}; +using CpuSlotKey = std::pair; +using ClassSlotKey = std::pair; struct MaterializedClass { ClassId id = 0; @@ -119,6 +79,72 @@ struct MaterializedClass { DenseMap hostOutputToResultIndex; }; +struct PackedScalarRunSlot { + SmallVector keys; +}; + +enum class PackedScalarRunKind { + Materialized, + DeferredReceive, + DeferredLocalCompute +}; + +struct PackedScalarRunValue { + ClassId targetClass = 0; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + PackedScalarRunKind kind = PackedScalarRunKind::Materialized; + + Value packed; + + RankedTensorType fragmentType; + SmallVector slots; + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; +}; + +struct MaterializationRunSlot { + SmallVector peers; +}; + +using MaterializationRun = SmallVector; + +struct OutputDestinationGroup { + SmallVector resultIndices; + SmallVector destinationClasses; +}; + +struct BatchRunSendPlan { + size_t resultIndex = 0; + ClassId destinationClass = 0; + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; +}; + +struct MaterializerState; + +class AvailableValueStore { +public: + void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; } + + void recordPackedRun(PackedScalarRunValue run) { packedScalarRuns.push_back(std::move(run)); } + + std::optional lookupExact(ProducerKey key, ClassId classId) const; + + std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); + + SmallVectorImpl& getPackedScalarRuns() { return packedScalarRuns; } + +private: + std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); + + DenseMap, ProducerKeyInfo> exactValues; + SmallVector packedScalarRuns; +}; + struct MaterializerState { func::FuncOp func; const MergeScheduleResult& schedule; @@ -128,11 +154,11 @@ struct MaterializerState { SmallVector classes; DenseMap cpuToClass; - DenseMap cpuSlotToInstance; - DenseSet materializedSlots; + DenseMap cpuSlotToInstance; + DenseSet materializedSlots; DenseMap, ProducerKeyInfo> producerDestClasses; - DenseMap, ProducerKeyInfo> availableValues; + AvailableValueStore availableValues; DenseMap hostReplacements; DenseSet oldComputeOps; @@ -251,16 +277,10 @@ FailureOr getPackedBatchTensorType(Type laneType, size_t laneC return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); } -std::optional lookupAvailableValue(MaterializerState& state, ProducerKey key, ClassId classId) { - auto producerIt = state.availableValues.find(key); - if (producerIt == state.availableValues.end()) - return std::nullopt; - - auto valueIt = producerIt->second.find(classId); - if (valueIt == producerIt->second.end()) - return std::nullopt; - - return valueIt->second; +LogicalResult verifyPackableFragmentType(Operation* anchor, Type fragmentType, size_t count, StringRef message) { + if (failed(getPackedBatchTensorType(fragmentType, count))) + return anchor->emitError(message); + return success(); } std::optional getProducerKey(Value value, const ComputeInstance* consumerInstance = nullptr) { @@ -558,6 +578,315 @@ Value createIndexConstant(MaterializerState& state, Operation* anchor, int64_t v return getOrCreateHostIndexConstant(anchor, value, state.constantFolder); } +// ----------------------------------------------------------------------------- +// Tensor packing helpers. +// ----------------------------------------------------------------------------- + +struct Dim0SliceParams { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +Dim0SliceParams +buildDim0SliceParams(OpBuilder& builder, RankedTensorType referenceType, OpFoldResult firstOffset, int64_t firstSize) { + Dim0SliceParams params; + params.offsets.reserve(referenceType.getRank()); + params.sizes.reserve(referenceType.getRank()); + params.strides.reserve(referenceType.getRank()); + + params.offsets.push_back(firstOffset); + params.sizes.push_back(builder.getIndexAttr(firstSize)); + params.strides.push_back(builder.getIndexAttr(1)); + + for (int64_t dim = 1; dim < referenceType.getRank(); ++dim) { + params.offsets.push_back(builder.getIndexAttr(0)); + params.sizes.push_back(builder.getIndexAttr(referenceType.getDimSize(dim))); + params.strides.push_back(builder.getIndexAttr(1)); + } + + return params; +} + +Value createDim0ExtractSlice( + MaterializerState& state, Location loc, Value source, OpFoldResult firstOffset, int64_t firstSize) { + auto sourceType = cast(source.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, sourceType, firstOffset, firstSize); + return tensor::ExtractSliceOp::create(state.rewriter, loc, source, params.offsets, params.sizes, params.strides) + .getResult(); +} + +Value createDim0InsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { + auto fragmentType = cast(fragment.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); + return tensor::InsertSliceOp::create( + state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides) + .getResult(); +} + +void createDim0ParallelInsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { + auto fragmentType = cast(fragment.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); + tensor::ParallelInsertSliceOp::create( + state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides); +} + +Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc) { + if (dim0Size == 1) + return index; + + Value dim0SizeValue = createIndexConstant(state, anchor, dim0Size); + return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); +} + +bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) { + return lhs.instance.op == rhs.instance.op && lhs.resultIndex == rhs.resultIndex; +} + +bool containsProducerKey(ProducerKey outer, ProducerKey inner) { + if (!sameProducerResult(outer, inner)) + return false; + if (!isa(outer.instance.op)) + return false; + if (outer.instance.laneCount == 0 || inner.instance.laneCount == 0) + return false; + + uint32_t outerStart = outer.instance.laneStart; + uint32_t outerEnd = outerStart + outer.instance.laneCount; + uint32_t innerStart = inner.instance.laneStart; + uint32_t innerEnd = innerStart + inner.instance.laneCount; + + return outerStart <= innerStart && innerEnd <= outerEnd; +} + +std::optional extractPackedProducerSlice(MaterializerState& state, + MaterializedClass& materializedClass, + ProducerKey packedKey, + Value packed, + ProducerKey requestedKey) { + if (!containsProducerKey(packedKey, requestedKey)) + return std::nullopt; + + auto packedType = dyn_cast(packed.getType()); + if (!packedType || !packedType.hasStaticShape() || packedType.getRank() == 0) + return std::nullopt; + + if (packedKey.instance.laneCount == 0) + return std::nullopt; + + int64_t packedRows = packedType.getDimSize(0); + if (packedRows % static_cast(packedKey.instance.laneCount) != 0) + return std::nullopt; + + int64_t rowsPerLane = packedRows / static_cast(packedKey.instance.laneCount); + int64_t rowOffset = + static_cast(requestedKey.instance.laneStart - packedKey.instance.laneStart) * rowsPerLane; + int64_t rowCount = static_cast(requestedKey.instance.laneCount) * rowsPerLane; + + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + Value firstOffset = createIndexConstant(state, materializedClass.op, rowOffset); + return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); +} + +std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { + auto producerIt = exactValues.find(key); + if (producerIt == exactValues.end()) + return std::nullopt; + + auto valueIt = producerIt->second.find(classId); + if (valueIt == producerIt->second.end()) + return std::nullopt; + + return valueIt->second; +} + +Value getPackedSliceForRunIndex(MaterializerState& state, + Operation* anchor, + Value packed, + RankedTensorType fragmentType, + size_t index, + Location loc) { + int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); + Value firstOffset = createIndexConstant(state, anchor, rowOffset); + return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); +} + +SmallVector widenToI64(ArrayRef values) { + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); + return widened; +} + +Value createReceiveConcatLoop(MaterializerState& state, + Operation* anchor, + Operation* insertionPoint, + RankedTensorType concatType, + RankedTensorType fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc); + +FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc); + +bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { + return run.kind == PackedScalarRunKind::DeferredLocalCompute; +} + +size_t getPackedScalarRunReceiveCount(const PackedScalarRunValue& run) { + size_t count = 0; + for (const PackedScalarRunSlot& slot : run.slots) + count += slot.keys.size(); + return count; +} + +LogicalResult validatePackedScalarRunMetadata(Operation* anchor, const PackedScalarRunValue& run) { + if (run.kind == PackedScalarRunKind::DeferredLocalCompute) + return success(); + + size_t receiveCount = getPackedScalarRunReceiveCount(run); + + if (receiveCount == 0) + return anchor->emitError("packed scalar run has no receives"); + + if (run.channelIds.size() != receiveCount || run.sourceCoreIds.size() != receiveCount + || run.targetCoreIds.size() != receiveCount) + return anchor->emitError("packed scalar run receive metadata is inconsistent"); + + return success(); +} + +FailureOr materializePackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc) { + if (run.packed) + return run.packed; + + if (run.kind == PackedScalarRunKind::Materialized) + return targetClass.op->emitError("materialized packed scalar run has no packed value"); + + if (isDeferredLocalPackedScalarRun(run)) + return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc); + + if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) + return failure(); + + FailureOr fullPackedType = + getPackedBatchTensorType(run.fragmentType, getPackedScalarRunReceiveCount(run)); + if (failed(fullPackedType)) + return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); + + SmallVector sourceCoreIds = widenToI64(run.sourceCoreIds); + SmallVector targetCoreIds = widenToI64(run.targetCoreIds); + + run.packed = createReceiveConcatLoop(state, + targetClass.op, + targetClass.body->getTerminator(), + *fullPackedType, + run.fragmentType, + run.channelIds, + sourceCoreIds, + targetCoreIds, + loc); + return run.packed; +} + +std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { + for (PackedScalarRunValue& run : packedScalarRuns) { + if (run.targetClass != classId) + continue; + if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + + for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { + std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); + if (!slotKey || !containsProducerKey(*slotKey, key)) + continue; + + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return std::nullopt; + + MaterializedClass& materializedClass = state.classes[classId]; + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + FailureOr packed = + materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); + if (failed(packed)) + return std::nullopt; + + Value slotPacked = + getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); + + if (*slotKey == key) { + record(key, classId, slotPacked); + return slotPacked; + } + + std::optional sliced = extractPackedProducerSlice(state, materializedClass, *slotKey, slotPacked, key); + if (!sliced) + return std::nullopt; + + record(key, classId, *sliced); + return *sliced; + } + } + + return std::nullopt; +} + +std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { + if (std::optional exact = lookupExact(key, classId)) + return exact; + + if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) + return packedRunValue; + + if (key.instance.laneCount != 1) + return std::nullopt; + + MaterializedClass& materializedClass = state.classes[classId]; + + ProducerKey containingKey; + Value containingValue; + bool foundContainingValue = false; + + for (auto& entry : exactValues) { + ProducerKey candidateKey = entry.first; + if (!containsProducerKey(candidateKey, key)) + continue; + + auto valueIt = entry.second.find(classId); + if (valueIt == entry.second.end()) + continue; + + containingKey = candidateKey; + containingValue = valueIt->second; + foundContainingValue = true; + break; + } + + if (!foundContainingValue) + return std::nullopt; + + std::optional slice = + extractPackedProducerSlice(state, materializedClass, containingKey, containingValue, key); + if (!slice) + return std::nullopt; + + record(key, classId, *slice); + return *slice; +} + Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { SmallVector elements; elements.reserve(values.size()); @@ -569,14 +898,6 @@ Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, Arr return getOrCreateHostConstant(anchor, attr, type, state.constantFolder); } -Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { - SmallVector widened; - widened.reserve(values.size()); - for (int32_t value : values) - widened.push_back(value); - return createIndexTensorConstant(state, anchor, ArrayRef(widened)); -} - bool allEqual(ArrayRef values) { assert(!values.empty() && "expected at least one value"); for (int64_t value : values.drop_front()) @@ -585,14 +906,6 @@ bool allEqual(ArrayRef values) { return true; } -bool allEqual(ArrayRef values) { - assert(!values.empty() && "expected at least one value"); - for (int32_t value : values.drop_front()) - if (value != values.front()) - return false; - return true; -} - struct IndexedIndexPattern { int64_t base = 0; int64_t step = 0; @@ -873,6 +1186,10 @@ ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey ke return it->second; } +// ----------------------------------------------------------------------------- +// Communication materialization helpers. +// ----------------------------------------------------------------------------- + void appendScalarSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -992,242 +1309,175 @@ Value appendReceive(MaterializerState& state, state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); } -Value appendPackedScalarReceives(MaterializerState& state, - MaterializedClass& targetClass, - Type fragmentType, - Type packedType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { - assert(!targetClass.isBatch && "packed scalar receive helper expects a scalar target class"); - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(!channelIds.empty() && "expected at least one receive"); +LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Type fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { + if (!sourceClass.isBatch) + return sourceClass.op->emitError("lazy packed scalar receives expect a batch source class"); - SmallVector fragments; - fragments.reserve(channelIds.size()); - for (auto index : llvm::seq(0, channelIds.size())) { - fragments.push_back(appendScalarReceive( - state, targetClass, fragmentType, channelIds[index], sourceCoreIds[index], targetCoreIds[index], loc)); + if (targetClass.isBatch) + return targetClass.op->emitError("lazy packed scalar receives expect a scalar target class"); + + if (keys.empty()) + return sourceClass.op->emitError("lazy packed scalar receive expects at least one producer key"); + + if (keys.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError("lazy packed scalar receive expects one producer key per source lane"); + + if (keys.size() != channelIds.size() || keys.size() != sourceCoreIds.size() || keys.size() != targetCoreIds.size()) + return targetClass.op->emitError("lazy packed scalar receive metadata is inconsistent"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("lazy packed scalar receive expects a static ranked fragment type"); + + if (failed(verifyPackableFragmentType( + targetClass.op, fragmentType, keys.size(), "cannot create lazy packed scalar receive type"))) + return failure(); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return sourceClass.op->emitError("lazy packed scalar receive expects one producer result"); + + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("lazy packed scalar receive expects one lane per producer key"); } - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + PackedScalarRunValue packedRun; + packedRun.targetClass = targetClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredReceive; + packedRun.fragmentType = rankedFragmentType; - Value packed = fragments.front(); - if (fragments.size() != 1) - packed = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult(); + llvm::append_range(packedRun.channelIds, channelIds); + llvm::append_range(packedRun.sourceCoreIds, sourceCoreIds); + llvm::append_range(packedRun.targetCoreIds, targetCoreIds); - if (packed.getType() != packedType) - packed = tensor::CastOp::create(state.rewriter, loc, packedType, packed).getResult(); + PackedScalarRunSlot slot; + llvm::append_range(slot.keys, keys); + packedRun.slots.push_back(std::move(slot)); - return packed; + if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) + return failure(); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); } -std::optional getConstantIndexValue(Value value) { - if (auto constant = value.getDefiningOp()) - return constant.value(); - - APInt constantValue; - if (matchPattern(value, m_ConstantInt(&constantValue))) - return constantValue.getSExtValue(); - - return std::nullopt; -} - -bool getReceiveMetadata(SpatChannelReceiveOp receive, - int64_t& channelId, - int64_t& sourceCoreId, - int64_t& targetCoreId) { - // SpatChannelReceiveOp operands are: channel, source core, target core. - std::optional channel = getConstantIndexValue(receive->getOperand(0)); - std::optional source = getConstantIndexValue(receive->getOperand(1)); - std::optional target = getConstantIndexValue(receive->getOperand(2)); - if (!channel || !source || !target) - return false; - - channelId = *channel; - sourceCoreId = *source; - targetCoreId = *target; - return true; -} - -bool hasCompatibleConcatTypes(RankedTensorType concatType, RankedTensorType fragmentType, size_t fragmentCount) { - if (!concatType.hasStaticShape() || !fragmentType.hasStaticShape()) - return false; - if (concatType.getRank() != fragmentType.getRank()) - return false; - if (concatType.getRank() == 0) - return false; - if (concatType.getElementType() != fragmentType.getElementType()) - return false; - - if (concatType.getDimSize(0) != fragmentType.getDimSize(0) * static_cast(fragmentCount)) - return false; - - for (int64_t dim = 1; dim < concatType.getRank(); ++dim) - if (concatType.getDimSize(dim) != fragmentType.getDimSize(dim)) - return false; - - return true; -} - -Value createReceiveConcatLoop(MaterializerState& state, - Operation* anchor, - Operation* insertionPoint, - RankedTensorType concatType, - RankedTensorType fragmentType, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { - assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); - assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); - assert(!channelIds.empty() && "expected at least one receive"); - - Value lowerBound = createIndexConstant(state, anchor, 0); - Value upperBound = createIndexConstant(state, anchor, static_cast(channelIds.size())); - Value step = createIndexConstant(state, anchor, 1); - - state.rewriter.setInsertionPoint(insertionPoint); - Value init = - tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); - auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); - - Block* body = loop.getBody(); - if (!body->empty()) - if (auto yield = dyn_cast(body->back())) - state.rewriter.eraseOp(yield); - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPointToEnd(body); - - Value index = loop.getInductionVar(); - Value acc = body->getArgument(1); - - Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); - Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); - Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); - - Value received = - SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput(); - - Value firstOffset = index; - if (fragmentType.getDimSize(0) != 1) { - Value rowsPerFragment = createIndexConstant(state, anchor, fragmentType.getDimSize(0)); - firstOffset = arith::MulIOp::create(state.rewriter, loc, index, rowsPerFragment).getResult(); - } - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(fragmentType.getRank()); - sizes.reserve(fragmentType.getRank()); - strides.reserve(fragmentType.getRank()); - - offsets.push_back(firstOffset); - sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(0))); - strides.push_back(state.rewriter.getIndexAttr(1)); - - for (int64_t dim = 1; dim < fragmentType.getRank(); ++dim) { - offsets.push_back(state.rewriter.getIndexAttr(0)); - sizes.push_back(state.rewriter.getIndexAttr(fragmentType.getDimSize(dim))); - strides.push_back(state.rewriter.getIndexAttr(1)); - } - - Value next = tensor::InsertSliceOp::create(state.rewriter, loc, received, acc, offsets, sizes, strides).getResult(); - scf::YieldOp::create(state.rewriter, loc, next); - - return loop.getResult(0); -} - -bool compactReceiveConcat(MaterializerState& state, MaterializedClass& materializedClass, tensor::ConcatOp concat) { - auto dimAttr = concat->getAttrOfType("dim"); - if (!dimAttr || dimAttr.getInt() != 0) - return false; - - OperandRange inputs = concat->getOperands(); - if (inputs.size() < 2) - return false; - - SmallVector receives; - receives.reserve(inputs.size()); - - for (Value input : inputs) { - auto receive = input.getDefiningOp(); - if (!receive) - return false; - if (receive->getBlock() != concat->getBlock()) - return false; - if (!receive->getResult(0).hasOneUse()) - return false; - receives.push_back(receive); - } - - Operation* expected = concat.getOperation(); - for (SpatChannelReceiveOp receive : llvm::reverse(receives)) { - Operation* previous = expected->getPrevNode(); - if (previous != receive.getOperation()) - return false; - expected = previous; - } - - auto concatType = dyn_cast(concat->getResult(0).getType()); - auto fragmentType = dyn_cast(receives.front()->getResult(0).getType()); - if (!concatType || !fragmentType) - return false; - if (!hasCompatibleConcatTypes(concatType, fragmentType, receives.size())) - return false; - +struct ScalarSourceReceivePlan { + ClassId targetClass = 0; SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(receives.size()); - sourceCoreIds.reserve(receives.size()); - targetCoreIds.reserve(receives.size()); + SmallVector sourceCoreIds; + SmallVector targetCoreIds; +}; - for (SpatChannelReceiveOp receive : receives) { - if (receive->getResult(0).getType() != fragmentType) - return false; +SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { + SmallVector destinations; - int64_t channelId = 0; - int64_t sourceCoreId = 0; - int64_t targetCoreId = 0; - if (!getReceiveMetadata(receive, channelId, sourceCoreId, targetCoreId)) - return false; + for (ProducerKey key : keys) + for (ClassId destinationClass : getDestinationClasses(state, key)) + destinations.push_back(destinationClass); - channelIds.push_back(channelId); - sourceCoreIds.push_back(sourceCoreId); - targetCoreIds.push_back(targetCoreId); - } - - Value replacement = createReceiveConcatLoop(state, - materializedClass.op, - receives.front().getOperation(), - concatType, - fragmentType, - channelIds, - sourceCoreIds, - targetCoreIds, - concat.getLoc()); - - concat->getResult(0).replaceAllUsesWith(replacement); - state.rewriter.eraseOp(concat.getOperation()); - - for (SpatChannelReceiveOp receive : llvm::reverse(receives)) - state.rewriter.eraseOp(receive.getOperation()); - - return true; + llvm::sort(destinations); + destinations.erase(std::unique(destinations.begin(), destinations.end()), destinations.end()); + return destinations; } -void compactReceiveConcats(MaterializerState& state) { - SmallVector, 16> concatOps; +SmallVector emitScalarSourceSends(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef destinationClasses, + Value payload, + Location loc) { + assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); - for (MaterializedClass& materializedClass : state.classes) - materializedClass.op->walk([&](tensor::ConcatOp concat) { concatOps.push_back({&materializedClass, concat}); }); + int32_t sourceCpu = static_cast(sourceClass.cpus.front()); - for (auto [materializedClass, concat] : concatOps) - compactReceiveConcat(state, *materializedClass, concat); + size_t messageCount = 0; + for (ClassId destinationClass : destinationClasses) { + if (destinationClass == sourceClass.id) + continue; + + MaterializedClass& targetClass = state.classes[destinationClass]; + messageCount += targetClass.isBatch ? targetClass.cpus.size() : 1; + } + + SmallVector allChannelIds; + SmallVector allSourceCoreIds; + SmallVector allTargetCoreIds; + allChannelIds.reserve(messageCount); + allSourceCoreIds.reserve(messageCount); + allTargetCoreIds.reserve(messageCount); + + SmallVector receivePlans; + receivePlans.reserve(destinationClasses.size()); + + for (ClassId destinationClass : destinationClasses) { + if (destinationClass == sourceClass.id) + continue; + + MaterializedClass& targetClass = state.classes[destinationClass]; + + ScalarSourceReceivePlan plan; + plan.targetClass = destinationClass; + + auto appendMessage = [&](int32_t targetCpu) { + int64_t channelId = state.nextChannelId++; + + plan.channelIds.push_back(channelId); + plan.sourceCoreIds.push_back(sourceCpu); + plan.targetCoreIds.push_back(targetCpu); + + allChannelIds.push_back(channelId); + allSourceCoreIds.push_back(sourceCpu); + allTargetCoreIds.push_back(targetCpu); + }; + + if (!targetClass.isBatch) + appendMessage(static_cast(targetClass.cpus.front())); + else + for (CpuId targetCpu : targetClass.cpus) + appendMessage(static_cast(targetCpu)); + + receivePlans.push_back(std::move(plan)); + } + + if (!allChannelIds.empty()) + appendSend(state, sourceClass, payload, allChannelIds, allSourceCoreIds, allTargetCoreIds, loc); + + return receivePlans; +} + +LogicalResult emitScalarSourceCommunication( + MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { + assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); + + for (ProducerKey key : keys) + state.availableValues.record(key, sourceClass.id, payload); + + SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); + SmallVector receivePlans = + emitScalarSourceSends(state, sourceClass, destinationClasses, payload, loc); + + for (const ScalarSourceReceivePlan& plan : receivePlans) { + MaterializedClass& targetClass = state.classes[plan.targetClass]; + + Value received = appendReceive( + state, targetClass, payload.getType(), plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + } + + return success(); } LogicalResult emitClassToClassCommunication(MaterializerState& state, @@ -1238,63 +1488,19 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, Location loc) { if (sourceClass.id == targetClass.id) { for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = payload; + state.availableValues.record(key, targetClass.id, payload); return success(); } - if (!sourceClass.isBatch && !targetClass.isBatch) { - int64_t channelId = state.nextChannelId++; - int32_t sourceCpu = static_cast(sourceClass.cpus.front()); - int32_t targetCpu = static_cast(targetClass.cpus.front()); + if (!sourceClass.isBatch) + return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); - SmallVector channelIds {channelId}; - SmallVector sourceCoreIds {sourceCpu}; - SmallVector targetCoreIds {targetCpu}; - - appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = - appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); - - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = received; - return success(); - } - - if (!sourceClass.isBatch && targetClass.isBatch) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(targetClass.cpus.size()); - sourceCoreIds.reserve(targetClass.cpus.size()); - targetCoreIds.reserve(targetClass.cpus.size()); - - int32_t sourceCpu = static_cast(sourceClass.cpus.front()); - for (CpuId targetCpu : targetClass.cpus) { - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(sourceCpu); - targetCoreIds.push_back(static_cast(targetCpu)); - } - - appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = - appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); - - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = received; - return success(); - } - - if (sourceClass.isBatch && !targetClass.isBatch) { + if (!targetClass.isBatch) { std::optional packedKey = getContiguousProducerKeyForKeys(keys); if (!packedKey) return sourceClass.op->emitError( "cannot materialize batch-to-scalar communication because source lanes are not contiguous"); - FailureOr packedType = getPackedBatchTensorType(payload.getType(), keys.size()); - if (failed(packedType)) - return sourceClass.op->emitError( - "cannot materialize batch-to-scalar communication for non-static ranked tensor payload"); - SmallVector channelIds; SmallVector sourceCoreIds; SmallVector targetCoreIds; @@ -1310,39 +1516,34 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, } appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = appendPackedScalarReceives( - state, targetClass, payload.getType(), *packedType, channelIds, sourceCoreIds, targetCoreIds, loc); - - state.availableValues[*packedKey][targetClass.id] = received; - return success(); + return registerLazyPackedScalarReceives( + state, sourceClass, targetClass, keys, payload.getType(), channelIds, sourceCoreIds, targetCoreIds); } - if (sourceClass.isBatch && targetClass.isBatch) { - if (sourceClass.cpus.size() != targetClass.cpus.size()) - return sourceClass.op->emitError( - "cannot materialize batch communication between equivalence classes of different sizes"); + if (sourceClass.cpus.size() != targetClass.cpus.size()) + return sourceClass.op->emitError( + "cannot materialize batch communication between equivalence classes of different sizes"); - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sourceClass.cpus.size()); - sourceCoreIds.reserve(sourceClass.cpus.size()); - targetCoreIds.reserve(targetClass.cpus.size()); + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(sourceClass.cpus.size()); + sourceCoreIds.reserve(sourceClass.cpus.size()); + targetCoreIds.reserve(targetClass.cpus.size()); - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(static_cast(sourceCpu)); - targetCoreIds.push_back(static_cast(targetClass.cpus[lane])); - } - - appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = - appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); - - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = received; - return success(); + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(static_cast(sourceCpu)); + targetCoreIds.push_back(static_cast(targetClass.cpus[lane])); } + + appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + + return success(); } LogicalResult @@ -1384,24 +1585,7 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(payloadType.getRank()); - sizes.reserve(payloadType.getRank()); - strides.reserve(payloadType.getRank()); - - offsets.push_back(*laneArg); - sizes.push_back(state.rewriter.getIndexAttr(1)); - strides.push_back(state.rewriter.getIndexAttr(1)); - - for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { - offsets.push_back(state.rewriter.getIndexAttr(0)); - sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); - strides.push_back(state.rewriter.getIndexAttr(1)); - } - - tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides); + createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); return success(); } @@ -1423,14 +1607,10 @@ LogicalResult emitOutputFanout(MaterializerState& state, return success(); if (!sourceClass.isBatch) { - for (ClassId destinationClass : getDestinationClasses(state, keys.front())) - if (failed( - emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) - return failure(); - if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) + if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) return failure(); - state.availableValues[keys.front()][sourceClass.id] = payload; - return success(); + + return emitHostCommunication(state, sourceClass, payload, originalOutput); } if (!haveSameDestinationClasses(state, keys)) @@ -1445,33 +1625,433 @@ LogicalResult emitOutputFanout(MaterializerState& state, return failure(); for (ProducerKey key : keys) - state.availableValues[key][sourceClass.id] = payload; + state.availableValues.record(key, sourceClass.id, payload); + return success(); } -FailureOr materializeWholeBatchInput( - MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { - auto batch = dyn_cast_or_null(key.instance.op); - auto resultTensorType = dyn_cast(resultType); - if (!batch || !resultTensorType || resultTensorType.getRank() == 0) +struct WholeBatchAssemblyRange { + uint32_t laneStart = 0; + uint32_t laneCount = 0; +}; + +struct DirectWholeBatchFragment { + ProducerKey key; + Value fragment; +}; + +struct WholeBatchAssemblyPlan { + RankedTensorType resultType; + int64_t rowsPerLane = 0; + + SmallVector coveredRanges; + SmallVector packedRuns; + SmallVector directFragments; +}; + +bool wholeBatchRangeOverlaps(ArrayRef ranges, uint32_t laneStart, uint32_t laneCount) { + uint32_t laneEnd = laneStart + laneCount; + for (WholeBatchAssemblyRange range : ranges) { + uint32_t rangeEnd = range.laneStart + range.laneCount; + if (laneStart < rangeEnd && range.laneStart < laneEnd) + return true; + } + return false; +} + +bool wholeBatchLaneCovered(ArrayRef ranges, uint32_t lane) { + for (WholeBatchAssemblyRange range : ranges) + if (range.laneStart <= lane && lane < range.laneStart + range.laneCount) + return true; + return false; +} + +void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { + plan.coveredRanges.push_back({laneStart, laneCount}); +} + +LogicalResult +validateWholeBatchFragmentType(RankedTensorType resultType, RankedTensorType fragmentType, int64_t expectedRows) { + if (!fragmentType.hasStaticShape()) + return failure(); + if (fragmentType.getRank() != resultType.getRank()) + return failure(); + if (fragmentType.getDimSize(0) != expectedRows) return failure(); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + for (int64_t dim = 1; dim < resultType.getRank(); ++dim) + if (fragmentType.getDimSize(dim) != resultType.getDimSize(dim)) + return failure(); - uint32_t batchLaneCount = batch.getLaneCount(); - SmallVector fragments; + return success(); +} + +// ----------------------------------------------------------------------------- +// Packed run tensor assembly helpers. +// ----------------------------------------------------------------------------- + +Value insertFragmentIntoWholeBatch( + MaterializerState& state, Value fragment, Value destination, OpFoldResult firstOffset, Location loc) { + return createDim0InsertSlice(state, loc, fragment, destination, firstOffset); +} + +Value extractPackedSlotForIndex(MaterializerState& state, + Operation* anchor, + Value packed, + RankedTensorType slotPackedType, + Value slotIndex, + Location loc) { + Value firstOffset = scaleIndexByDim0Size(state, anchor, slotIndex, slotPackedType.getDimSize(0), loc); + return createDim0ExtractSlice(state, loc, packed, firstOffset, slotPackedType.getDimSize(0)); +} + +SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { + SmallVector keys; + for (const PackedScalarRunSlot& slot : run.slots) + llvm::append_range(keys, slot.keys); + return keys; +} + +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value laneValue, + ArrayRef resultIndices); + +FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc) { + assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); + + SmallVector keys = flattenPackedScalarRunKeys(run); + if (keys.empty()) + return failure(); + + FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); + if (failed(packedType)) + return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); + + SmallVector sourceLanes; + sourceLanes.reserve(keys.size()); + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return failure(); + sourceLanes.push_back(key.instance.laneStart); + } + + SmallVector resultIndices {run.resultIndex}; + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); + + Value lowerBound = createIndexConstant(state, targetClass.op, 0); + Value upperBound = createIndexConstant(state, targetClass.op, static_cast(keys.size())); + Value step = createIndexConstant(state, targetClass.op, 1); + + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value loopIndex = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices); + if (failed(produced)) + return failure(); + + if (produced->size() != 1) + return failure(); + + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, run.fragmentType.getDimSize(0), loc); + Value next = createDim0InsertSlice(state, loc, produced->front(), acc, firstOffset); + + scf::YieldOp::create(state.rewriter, loc, next); + + run.packed = loop.getResult(0); + return run.packed; +} + +FailureOr insertDeferredLocalPackedScalarRunIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchAssemblyPlan& plan, + PackedScalarRunValue& run, + Location loc) { + assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); + + if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) + return failure(); + + SmallVector keys = flattenPackedScalarRunKeys(run); + if (keys.empty()) + return failure(); + + SmallVector sourceLanes; + SmallVector outputOffsets; + sourceLanes.reserve(keys.size()); + outputOffsets.reserve(keys.size()); + + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return failure(); + + sourceLanes.push_back(key.instance.laneStart); + outputOffsets.push_back(static_cast(key.instance.laneStart) * plan.rowsPerLane); + } + + SmallVector resultIndices {run.resultIndex}; + + Value lowerBound = createIndexConstant(state, targetClass.op, 0); + Value upperBound = createIndexConstant(state, targetClass.op, static_cast(keys.size())); + Value step = createIndexConstant(state, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value loopIndex = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, targetClass, keys.front().instance, sourceLane, resultIndices); + if (failed(produced)) + return failure(); + + if (produced->size() != 1) + return failure(); + + Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, loopIndex, loc); + Value next = insertFragmentIntoWholeBatch(state, produced->front(), acc, outputOffset, loc); + + scf::YieldOp::create(state.rewriter, loc, next); + return loop.getResult(0); +} + +FailureOr insertDeferredPackedScalarRunIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchAssemblyPlan& plan, + PackedScalarRunValue& run, + Location loc) { + assert(run.kind == PackedScalarRunKind::DeferredReceive && "expected deferred receive packed scalar run"); + + if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) + return failure(); + + if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) + return failure(); + + SmallVector outputOffsets; + outputOffsets.reserve(getPackedScalarRunReceiveCount(run)); + + for (const PackedScalarRunSlot& slot : run.slots) { + for (ProducerKey key : slot.keys) { + if (key.instance.laneCount == 0) + return failure(); + + outputOffsets.push_back(static_cast(key.instance.laneStart) * plan.rowsPerLane); + } + } + + if (outputOffsets.size() != run.channelIds.size()) + return failure(); + + Value lowerBound = createIndexConstant(state, targetClass.op, 0); + Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.channelIds.size())); + Value step = createIndexConstant(state, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value index = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value channelId = createIndexedIndexValue(state, targetClass.op, run.channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, targetClass.op, run.sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, targetClass.op, run.targetCoreIds, index, loc); + + Value received = + SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + + Value outputOffset = createIndexedIndexValue(state, targetClass.op, outputOffsets, index, loc); + Value next = insertFragmentIntoWholeBatch(state, received, acc, outputOffset, loc); + + scf::YieldOp::create(state.rewriter, loc, next); + return loop.getResult(0); +} + +FailureOr insertPackedScalarRunIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchAssemblyPlan& plan, + PackedScalarRunValue& run, + Location loc) { + if (run.slots.empty()) + return destination; + + if (!run.packed) { + if (isDeferredLocalPackedScalarRun(run)) + return insertDeferredLocalPackedScalarRunIntoWholeBatch(state, targetClass, destination, plan, run, loc); + return insertDeferredPackedScalarRunIntoWholeBatch(state, targetClass, destination, plan, run, loc); + } + + auto sourceBatch = dyn_cast_or_null(run.sourceOp); + if (!sourceBatch) + return failure(); + + int64_t batchLaneCount = sourceBatch.getLaneCount(); + if (batchLaneCount <= 0) + return failure(); + + if (run.fragmentType.getDimSize(0) != plan.rowsPerLane) + return failure(); + + size_t slotLaneCount = run.slots.front().keys.size(); + if (slotLaneCount == 0) + return failure(); + + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slotLaneCount); + if (failed(slotPackedType)) + return failure(); + + SmallVector slotRowOffsets; + slotRowOffsets.reserve(run.slots.size()); + + for (const PackedScalarRunSlot& slot : run.slots) { + if (slot.keys.size() != slotLaneCount) + return failure(); + + std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); + if (!slotKey) + return failure(); + + slotRowOffsets.push_back(static_cast(slotKey->instance.laneStart) * plan.rowsPerLane); + } + + Value lowerBound = createIndexConstant(state, targetClass.op, 0); + Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.slots.size())); + Value step = createIndexConstant(state, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {destination}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value slotIndex = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value slotPacked = extractPackedSlotForIndex(state, targetClass.op, run.packed, *slotPackedType, slotIndex, loc); + Value outputOffset = createIndexedIndexValue(state, targetClass.op, slotRowOffsets, slotIndex, loc); + Value next = insertFragmentIntoWholeBatch(state, slotPacked, acc, outputOffset, loc); + + scf::YieldOp::create(state.rewriter, loc, next); + return loop.getResult(0); +} + +LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + WholeBatchAssemblyPlan& plan) { + for (PackedScalarRunValue& run : state.availableValues.getPackedScalarRuns()) { + if (run.targetClass != targetClass.id) + continue; + if (run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + + SmallVector runRanges; + runRanges.reserve(run.slots.size()); + + for (const PackedScalarRunSlot& slot : run.slots) { + std::optional slotKey = getContiguousProducerKeyForKeys(slot.keys); + if (!slotKey) + return failure(); + + if (wholeBatchRangeOverlaps(plan.coveredRanges, slotKey->instance.laneStart, slotKey->instance.laneCount)) + return failure(); + + runRanges.push_back({slotKey->instance.laneStart, slotKey->instance.laneCount}); + } + + plan.packedRuns.push_back(&run); + + for (WholeBatchAssemblyRange range : runRanges) + recordWholeBatchCoverage(plan, range.laneStart, range.laneCount); + } + + return success(); +} + +LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + SpatComputeBatch batch, + ProducerKey key, + WholeBatchAssemblyPlan& plan) { + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); uint32_t lane = 0; while (lane < batchLaneCount) { + if (wholeBatchLaneCovered(plan.coveredRanges, lane)) { + ++lane; + continue; + } + bool foundFragment = false; for (uint32_t laneCount = batchLaneCount - lane; laneCount != 0; --laneCount) { + if (wholeBatchRangeOverlaps(plan.coveredRanges, lane, laneCount)) + continue; + ProducerKey candidate = getBatchLaneProducerKey(batch, lane, laneCount, key.resultIndex); - std::optional fragment = lookupAvailableValue(state, candidate, targetClass.id); + std::optional fragment = state.availableValues.lookupExact(candidate, targetClass.id); if (!fragment) continue; - fragments.push_back(*fragment); + auto fragmentType = dyn_cast(fragment->getType()); + if (!fragmentType) + return failure(); + + int64_t expectedRows = plan.rowsPerLane * static_cast(laneCount); + if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) + return failure(); + + plan.directFragments.push_back({candidate, *fragment}); + recordWholeBatchCoverage(plan, lane, laneCount); + lane += laneCount; foundFragment = true; break; @@ -1481,20 +2061,78 @@ FailureOr materializeWholeBatchInput( return failure(); } - if (fragments.empty()) + return success(); +} + +FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType) { + auto batch = dyn_cast_or_null(key.instance.op); + auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || !resultTensorType.hasStaticShape() || resultTensorType.getRank() == 0) return failure(); - Value result = fragments.front(); - if (fragments.size() != 1) - result = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult(); + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + if (batchLaneCount == 0 || resultTensorType.getDimSize(0) % static_cast(batchLaneCount) != 0) + return failure(); - if (result.getType() != resultType) - result = tensor::CastOp::create(state.rewriter, loc, resultType, result).getResult(); + WholeBatchAssemblyPlan plan; + plan.resultType = resultTensorType; + plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast(batchLaneCount); - state.availableValues[key][targetClass.id] = result; + if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) + return failure(); + + if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) + return failure(); + + return plan; +} + +FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + WholeBatchAssemblyPlan& plan, + Location loc) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, plan.resultType.getShape(), plan.resultType.getElementType()) + .getResult(); + + for (PackedScalarRunValue* run : plan.packedRuns) { + FailureOr updated = insertPackedScalarRunIntoWholeBatch(state, targetClass, result, plan, *run, loc); + if (failed(updated)) + return failure(); + + result = *updated; + } + + for (const DirectWholeBatchFragment& fragment : plan.directFragments) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + int64_t rowOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; + Value outputOffset = createIndexConstant(state, targetClass.op, rowOffset); + result = insertFragmentIntoWholeBatch(state, fragment.fragment, result, outputOffset, loc); + } + + state.availableValues.record(key, targetClass.id, result); return result; } +// ----------------------------------------------------------------------------- +// Run materialization helpers. +// ----------------------------------------------------------------------------- + +FailureOr materializeWholeBatchInput( + MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { + FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); + if (failed(plan)) + return failure(); + + return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); +} + FailureOr resolveInputValue(MaterializerState& state, MaterializedClass& targetClass, Value input, @@ -1503,7 +2141,7 @@ FailureOr resolveInputValue(MaterializerState& state, return input; if (std::optional producer = getProducerKey(input, &consumerInstance)) { - if (std::optional value = lookupAvailableValue(state, *producer, targetClass.id)) + if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) return *value; if (isWholeBatchProducerKey(*producer)) @@ -1595,6 +2233,35 @@ SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin return outputs; } +SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) { + SmallVector types(batch.getNumResults(), Type {}); + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return types; + + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return types; + + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; + + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (resultIndex >= types.size()) + continue; + + types[resultIndex] = insert.getSource().getType(); + } + + return types; +} + FailureOr> cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers) { assert(!peers.empty() && "expected at least one peer instance"); @@ -1624,10 +2291,14 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc)); } + OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); + mapWeights(state, targetClass, instance, mapper); if (failed(mapInputs(state, targetClass, instance, mapper))) return failure(); + state.rewriter.restoreInsertionPoint(cloneInsertionPoint); + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); for (Operation& op : sourceBlock.without_terminator()) { Operation* cloned = state.rewriter.clone(op, mapper); @@ -1662,6 +2333,825 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra return outputs; } +bool sameDestinationClasses(ArrayRef lhs, ArrayRef rhs) { + if (lhs.size() != rhs.size()) + return false; + for (auto [lhsClass, rhsClass] : llvm::zip(lhs, rhs)) + if (lhsClass != rhsClass) + return false; + return true; +} + +SmallVector +collectDestinationClassesForRun(MaterializerState& state, ArrayRef run, size_t resultIndex) { + SmallVector destinations; + + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + ProducerKey key {peer, resultIndex}; + for (ClassId destinationClass : getDestinationClasses(state, key)) + if (!llvm::is_contained(destinations, destinationClass)) + destinations.push_back(destinationClass); + } + } + + llvm::sort(destinations); + return destinations; +} + +SmallVector groupBatchRunOutputsByDestination(MaterializerState& state, + ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + + SmallVector groups; + SmallVector outputs = getComputeInstanceOutputValues(run.front().peers.front()); + + for (auto [resultIndex, output] : llvm::enumerate(outputs)) { + SmallVector destinations = collectDestinationClassesForRun(state, run, resultIndex); + + auto existingGroup = llvm::find_if(groups, [&](const OutputDestinationGroup& group) { + return sameDestinationClasses(group.destinationClasses, destinations); + }); + + if (existingGroup != groups.end()) { + existingGroup->resultIndices.push_back(resultIndex); + continue; + } + + OutputDestinationGroup group; + group.resultIndices.push_back(resultIndex); + group.destinationClasses = std::move(destinations); + groups.push_back(std::move(group)); + } + + return groups; +} + +FailureOr getPackedRunTensorType(Type elementType, size_t runSize) { + auto tensorType = dyn_cast(elementType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return failure(); + + SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(runSize); + return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +LogicalResult registerDeferredLocalPackedRunValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Type fragmentType, + Location loc) { + if (keys.empty()) + return success(); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return materializedClass.op->emitError("deferred local packed run expects static ranked fragment type"); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("deferred local packed run expects one producer result"); + + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("deferred local packed run expects one lane per fragment"); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredLocalCompute; + packedRun.fragmentType = rankedFragmentType; + + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult registerPackedRunValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Value packed, + Type fragmentType, + Location loc) { + if (keys.empty()) + return success(); + + FailureOr expectedPackedType = getPackedRunTensorType(fragmentType, keys.size()); + if (failed(expectedPackedType)) + return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); + + if (packed.getType() != *expectedPackedType) + return materializedClass.op->emitError("packed run value has unexpected tensor type"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("packed run registration expects one producer result"); + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("packed run registration expects one lane per packed fragment"); + } + + if (std::optional contiguousKey = getContiguousProducerKeyForKeys(keys)) { + state.availableValues.record(*contiguousKey, materializedClass.id, packed); + return success(); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.packed = packed; + packedRun.kind = PackedScalarRunKind::Materialized; + packedRun.fragmentType = rankedFragmentType; + + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult emitPackedRunFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef destinationClasses, + ArrayRef keys, + Value packed, + Type fragmentType, + Location loc) { + assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); + + SmallVector receivePlans = + emitScalarSourceSends(state, sourceClass, destinationClasses, packed, loc); + + for (const ScalarSourceReceivePlan& plan : receivePlans) { + MaterializedClass& targetClass = state.classes[plan.targetClass]; + + Value received = + appendReceive(state, targetClass, packed.getType(), plan.channelIds, plan.sourceCoreIds, plan.targetCoreIds, loc); + + if (failed(registerPackedRunValue(state, targetClass, keys, received, fragmentType, loc))) + return failure(); + } + + return success(); +} + +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value laneValue, + ArrayRef resultIndices) { + auto batch = dyn_cast(instance.op); + if (!batch) + return failure(); + + IRMapping mapper; + auto sourceLaneArg = batch.getLaneArgument(); + if (!sourceLaneArg) + return batch.emitOpError("expected source compute_batch lane block argument"); + + mapper.map(*sourceLaneArg, laneValue); + + OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); + + mapWeights(state, targetClass, instance, mapper); + if (failed(mapInputs(state, targetClass, instance, mapper))) + return failure(); + + state.rewriter.restoreInsertionPoint(cloneInsertionPoint); + + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); + for (Operation& op : sourceBlock.without_terminator()) { + Operation* cloned = state.rewriter.clone(op, mapper); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(oldResult, newResult); + } + + SmallVector allOutputs = collectMappedBatchOutputs(batch, mapper); + if (allOutputs.empty() && !resultIndices.empty()) + return batch.emitOpError("failed to recover source compute_batch outputs"); + + SmallVector selectedOutputs; + selectedOutputs.reserve(resultIndices.size()); + for (size_t resultIndex : resultIndices) { + if (resultIndex >= allOutputs.size() || !allOutputs[resultIndex]) + return batch.emitOpError("failed to recover selected compute_batch output"); + selectedOutputs.push_back(allOutputs[resultIndex]); + } + + return selectedOutputs; +} + +FailureOr> materializeBatchOutputGroupLoop(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const OutputDestinationGroup& group) { + assert(!run.empty() && "expected non-empty batch run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + + Operation* sourceOp = run.front().peers.front().op; + Location loc = sourceOp->getLoc(); + + if (run.size() == 1) { + if (run.front().peers.size() != 1) + return sourceOp->emitError("scalar batch output loop expects exactly one peer in singleton slot"); + + const ComputeInstance& instance = run.front().peers.front(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value laneValue = createIndexConstant(state, targetClass.op, instance.laneStart); + return cloneBatchBodyForLane(state, targetClass, instance, laneValue, group.resultIndices); + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + auto sourceBatch = cast(sourceOp); + SmallVector fragmentTypes = collectBatchOutputFragmentTypes(sourceBatch); + SmallVector initValues; + + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); + + Type fragmentType = fragmentTypes[resultIndex]; + FailureOr packedType = getPackedRunTensorType(fragmentType, run.size()); + if (failed(packedType)) + return sourceBatch.emitOpError("cannot materialize packed batch run for non-static ranked output"); + + initValues.push_back( + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); + } + + SmallVector laneStarts; + laneStarts.reserve(run.size()); + + for (const MaterializationRunSlot& slot : run) { + if (slot.peers.size() != 1) + return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); + + const ComputeInstance& instance = slot.peers.front(); + if (instance.op != sourceOp) + return sourceOp->emitError("materialization run contains different source operations"); + + laneStarts.push_back(instance.laneStart); + } + + Value lowerBound = createIndexConstant(state, targetClass.op, 0); + Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.size())); + Value step = createIndexConstant(state, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange(initValues)); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value loopIndex = loop.getInductionVar(); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, laneStarts, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices); + if (failed(produced)) + return failure(); + + SmallVector yielded; + yielded.reserve(produced->size()); + + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = cast(output.getType()); + Value acc = body->getArgument(1 + outputIndex); + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); + } + + scf::YieldOp::create(state.rewriter, loc, yielded); + + SmallVector results; + results.reserve(loop.getNumResults()); + for (Value result : loop.getResults()) + results.push_back(result); + return results; +} + +SmallVector getMaterializationRunSlotOutputKeys(const MaterializationRunSlot& slot, + size_t resultIndex) { + SmallVector keys; + keys.reserve(slot.peers.size()); + for (const ComputeInstance& peer : slot.peers) + keys.push_back({peer, resultIndex}); + return keys; +} + +FailureOr> +getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId slot) { + if (targetClass.isBatch) + return getPeerInstances(state, targetClass, slot); + + auto instanceIt = state.cpuSlotToInstance.find({targetClass.cpus.front(), slot}); + if (instanceIt == state.cpuSlotToInstance.end()) + return failure(); + + return SmallVector {instanceIt->second}; +} + +FailureOr collectBatchMaterializationRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + Operation* sourceOp) { + MaterializationRun run; + + for (SlotId slot = startSlot;; ++slot) { + ClassSlotKey classSlot {targetClass.id, slot}; + if (state.materializedSlots.contains(classSlot)) + break; + + FailureOr> peers = getMaterializationRunSlotPeers(state, targetClass, slot); + if (failed(peers) || peers->empty()) + break; + + bool validSlot = true; + for (const ComputeInstance& peer : *peers) { + if (peer.op != sourceOp || !isa(peer.op)) { + validSlot = false; + break; + } + + if (peer.laneCount != 1) + return peer.op->emitError("batch run materialization expects one scheduled source lane per materialized lane"); + } + + if (!validSlot) + break; + + MaterializationRunSlot runSlot; + runSlot.peers = std::move(*peers); + run.push_back(std::move(runSlot)); + } + + if (run.empty()) + return failure(); + + return run; +} + +SmallVector getMaterializationRunOutputKeys(ArrayRef run, size_t resultIndex) { + SmallVector keys; + for (const MaterializationRunSlot& slot : run) + llvm::append_range(keys, getMaterializationRunSlotOutputKeys(slot, resultIndex)); + return keys; +} + +SmallVector getFirstMaterializationRunOriginalOutputs(ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + return getComputeInstanceOutputValues(run.front().peers.front()); +} + +Operation* getMaterializationRunSourceOp(ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + return run.front().peers.front().op; +} + +Location getMaterializationRunLoc(ArrayRef run) { + return getMaterializationRunSourceOp(run)->getLoc(); +} + +bool hasMaterializationRunResultLiveExternalUse(MaterializerState& state, + ArrayRef run, + size_t resultIndex) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + SmallVector outputs = getComputeInstanceOutputValues(peer); + if (resultIndex >= outputs.size()) + return true; + + if (hasLiveExternalUse(outputs[resultIndex], state.oldComputeOps)) + return true; + } + } + + return false; +} + +bool hasMaterializationRunGroupLiveExternalUse(MaterializerState& state, + ArrayRef run, + const OutputDestinationGroup& group) { + for (size_t resultIndex : group.resultIndices) + if (hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) + return true; + + return false; +} + +void markMaterializationRunSlots(MaterializerState& state, + ClassId classId, + SlotId startSlot, + ArrayRef run) { + for (auto slotIndex : llvm::seq(0, run.size())) + state.materializedSlots.insert({classId, startSlot + static_cast(slotIndex)}); +} + +LogicalResult materializeScalarBatchRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + ArrayRef run) { + assert(!targetClass.isBatch && "scalar batch run materialization expects scalar target class"); + assert(!run.empty() && "expected non-empty batch run"); + + markMaterializationRunSlots(state, targetClass.id, startSlot, run); + + SmallVector groups = groupBatchRunOutputsByDestination(state, run); + SmallVector firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(run); + + auto sourceBatch = cast(getMaterializationRunSourceOp(run)); + SmallVector fragmentTypes = collectBatchOutputFragmentTypes(sourceBatch); + Location loc = getMaterializationRunLoc(run); + + for (const OutputDestinationGroup& group : groups) { + if (run.size() > 1 && group.destinationClasses.empty() + && !hasMaterializationRunGroupLiveExternalUse(state, run, group)) { + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); + + SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); + if (failed(registerDeferredLocalPackedRunValue(state, targetClass, keys, fragmentTypes[resultIndex], loc))) + return failure(); + } + + continue; + } + + FailureOr> packedOutputs = materializeBatchOutputGroupLoop(state, targetClass, run, group); + if (failed(packedOutputs)) + return failure(); + + for (auto [groupOutputIndex, resultIndex] : llvm::enumerate(group.resultIndices)) { + Value packed = (*packedOutputs)[groupOutputIndex]; + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); + + Type fragmentType = fragmentTypes[resultIndex]; + SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); + + if (run.size() == 1) { + if (failed(emitOutputFanout(state, targetClass, keys, packed, firstOriginalOutputs[resultIndex], loc))) + return failure(); + continue; + } + + if (failed(emitPackedRunFanout(state, targetClass, group.destinationClasses, keys, packed, fragmentType, loc))) + return failure(); + + if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) + return failure(); + + auto rankedFragmentType = cast(fragmentType); + for (auto [runIndex, slot] : llvm::enumerate(run)) { + assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); + + SmallVector originalOutputs = getComputeInstanceOutputValues(slot.peers.front()); + Value originalOutput = originalOutputs[resultIndex]; + + if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) + continue; + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value fragment = getPackedSliceForRunIndex(state, targetClass.op, packed, rankedFragmentType, runIndex, loc); + + if (failed(emitHostCommunication(state, targetClass, fragment, originalOutput))) + return failure(); + } + } + } + + return success(); +} + +bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { + for (const ComputeInstance& consumer : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(consumer); + if (cpuIt == state.schedule.computeToCpuMap.end()) + continue; + + if (state.cpuToClass.lookup(cpuIt->second) != classId) + continue; + + for (Value input : getComputeInstanceInputs(consumer)) { + std::optional producer = getProducerKey(input, &consumer); + if (!producer) + continue; + + for (ProducerKey expanded : expandWholeBatchProducerKey(*producer)) + if (expanded == producerKey) + return true; + } + } + + return false; +} + +bool canCompactBatchClassRun(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run) { + if (run.size() < 2) + return false; + if (run.front().peers.empty()) + return false; + + SmallVector outputs = getComputeInstanceOutputValues(run.front().peers.front()); + + for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { + for (const MaterializationRunSlot& slot : run) { + if (slot.peers.empty()) + return false; + + for (const ComputeInstance& peer : slot.peers) { + SmallVector peerOutputs = getComputeInstanceOutputValues(peer); + if (resultIndex >= peerOutputs.size()) + return false; + + Value originalOutput = peerOutputs[resultIndex]; + if (hasLiveExternalUse(originalOutput, state.oldComputeOps)) + return false; + + ProducerKey key {peer, resultIndex}; + if (hasSameClassConsumer(state, key, targetClass.id)) + return false; + } + } + } + + return true; +} + +Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { + auto batch = cast(targetClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected materialized compute_batch lane argument"); + + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr d1 = getAffineDimExpr(1, context); + + int64_t laneCount = static_cast(targetClass.cpus.size()); + AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); + return affine::AffineApplyOp::create(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}).getResult(); +} + +Value createBatchClassRunSourceLane(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + Value slotIndex, + Location loc) { + SmallVector sourceLanes; + sourceLanes.reserve(run.size() * targetClass.cpus.size()); + + for (const MaterializationRunSlot& slot : run) { + assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane"); + for (const ComputeInstance& peer : slot.peers) + sourceLanes.push_back(peer.laneStart); + } + + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + return createIndexedIndexValue(state, targetClass.op, sourceLanes, flatIndex, loc); +} + +LogicalResult buildBatchRunSendPlans(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const OutputDestinationGroup& group, + SmallVectorImpl& plans) { + assert(sourceClass.isBatch && "batch run send planning expects a materialized batch source"); + + for (size_t resultIndex : group.resultIndices) { + for (ClassId destinationClass : group.destinationClasses) { + if (destinationClass == sourceClass.id) + return sourceClass.op->emitError("batch-target run compaction cannot handle same-class consumers"); + + MaterializedClass& targetClass = state.classes[destinationClass]; + + if (targetClass.isBatch && targetClass.cpus.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError( + "cannot compact batch run communication between batch classes of different sizes"); + + BatchRunSendPlan plan; + plan.resultIndex = resultIndex; + plan.destinationClass = destinationClass; + + size_t messageCount = run.size() * sourceClass.cpus.size(); + plan.channelIds.reserve(messageCount); + plan.sourceCoreIds.reserve(messageCount); + plan.targetCoreIds.reserve(messageCount); + + for ([[maybe_unused]] const MaterializationRunSlot& slot : run) { + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + plan.channelIds.push_back(state.nextChannelId++); + plan.sourceCoreIds.push_back(static_cast(sourceCpu)); + + int32_t targetCpu = targetClass.isBatch ? static_cast(targetClass.cpus[lane]) + : static_cast(targetClass.cpus.front()); + plan.targetCoreIds.push_back(targetCpu); + } + } + + plans.push_back(std::move(plan)); + } + } + + return success(); +} + +void appendBatchRunSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const BatchRunSendPlan& plan, + Value flatIndex, + Location loc) { + assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); + + Value channelId = createIndexedIndexValue(state, sourceClass.op, plan.channelIds, flatIndex, loc); + Value sourceCoreId = createIndexedIndexValue(state, sourceClass.op, plan.sourceCoreIds, flatIndex, loc); + Value targetCoreId = createIndexedIndexValue(state, sourceClass.op, plan.targetCoreIds, flatIndex, loc); + + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); +} + +ArrayRef sliceChannelsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { + return ArrayRef(plan.channelIds).slice(slotIndex * laneCount, laneCount); +} + +ArrayRef sliceSourcesForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { + return ArrayRef(plan.sourceCoreIds).slice(slotIndex * laneCount, laneCount); +} + +ArrayRef sliceTargetsForRunSlot(const BatchRunSendPlan& plan, size_t slotIndex, size_t laneCount) { + return ArrayRef(plan.targetCoreIds).slice(slotIndex * laneCount, laneCount); +} + +LogicalResult appendPackedScalarRunReceives(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType, + Location loc) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + assert(!targetClass.isBatch && "packed scalar run receives expect a scalar target class"); + + size_t laneCount = sourceClass.cpus.size(); + size_t receiveCount = run.size() * laneCount; + + if (receiveCount != plan.channelIds.size() || receiveCount != plan.sourceCoreIds.size() + || receiveCount != plan.targetCoreIds.size()) + return targetClass.op->emitError("inconsistent flattened batch run receive plan"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("packed scalar run receive expects static ranked fragment type"); + + PackedScalarRunValue packedRun; + packedRun.targetClass = targetClass.id; + packedRun.sourceOp = run.front().peers.front().op; + packedRun.resultIndex = plan.resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredReceive; + packedRun.fragmentType = rankedFragmentType; + + packedRun.channelIds = plan.channelIds; + packedRun.sourceCoreIds = plan.sourceCoreIds; + packedRun.targetCoreIds = plan.targetCoreIds; + + packedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot packedSlot; + packedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + packedRun.slots.push_back(std::move(packedSlot)); + } + + if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) + return failure(); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult appendBatchRunReceives(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType, + Location loc) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + size_t laneCount = sourceClass.cpus.size(); + + if (!targetClass.isBatch) + return appendPackedScalarRunReceives(state, sourceClass, run, plan, fragmentType, loc); + + for (auto [slotIndex, slot] : llvm::enumerate(run)) { + SmallVector keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + + ArrayRef channelIds = sliceChannelsForRunSlot(plan, slotIndex, laneCount); + ArrayRef sourceCoreIds = sliceSourcesForRunSlot(plan, slotIndex, laneCount); + ArrayRef targetCoreIds = sliceTargetsForRunSlot(plan, slotIndex, laneCount); + + Value received = appendReceive(state, targetClass, fragmentType, channelIds, sourceCoreIds, targetCoreIds, loc); + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + } + + return success(); +} + +LogicalResult materializeBatchClassRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + ArrayRef run) { + assert(targetClass.isBatch && "batch-target run materialization expects a materialized batch class"); + assert(!run.empty() && "expected non-empty batch-target run"); + + if (!canCompactBatchClassRun(state, targetClass, run)) + return failure(); + + markMaterializationRunSlots(state, targetClass.id, startSlot, run); + + SmallVector groups = groupBatchRunOutputsByDestination(state, run); + + auto sourceBatch = cast(run.front().peers.front().op); + SmallVector fragmentTypes = collectBatchOutputFragmentTypes(sourceBatch); + Location loc = sourceBatch.getLoc(); + + for (const OutputDestinationGroup& group : groups) { + SmallVector sendPlans; + if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) + return failure(); + + Value lowerBound = createIndexConstant(state, targetClass.op, 0); + Value upperBound = createIndexConstant(state, targetClass.op, static_cast(run.size())); + Value step = createIndexConstant(state, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {}); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToStart(loop.getBody()); + + Value slotIndex = loop.getInductionVar(); + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, targetClass, run.front().peers.front(), sourceLane, group.resultIndices); + if (failed(produced)) + return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); + if (resultIt == group.resultIndices.end()) + return targetClass.op->emitError("internal error: missing compacted batch run result"); + + size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); + appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); + } + + for (const BatchRunSendPlan& plan : sendPlans) { + if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for compacted batch run"); + + if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) + return failure(); + } + } + + return success(); +} + LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) { auto cpuIt = state.schedule.computeToCpuMap.find(instance); if (cpuIt == state.schedule.computeToCpuMap.end()) @@ -1671,11 +3161,27 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns return instance.op->emitError("schedule materialization expected a CPU slot for every compute instance"); ClassId classId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& targetClass = state.classes[classId]; + ClassSlotKey classSlot {classId, slotIt->second}; + if (state.materializedSlots.contains(classSlot)) + return success(); + + if (isa(instance.op)) { + FailureOr run = collectBatchMaterializationRun(state, targetClass, slotIt->second, instance.op); + + if (succeeded(run)) { + if (!targetClass.isBatch) + return materializeScalarBatchRun(state, targetClass, slotIt->second, *run); + + if (succeeded(materializeBatchClassRun(state, targetClass, slotIt->second, *run))) + return success(); + } + } + if (!state.materializedSlots.insert(classSlot).second) return success(); - MaterializedClass& targetClass = state.classes[classId]; FailureOr> peers = getPeerInstances(state, targetClass, slotIt->second); if (failed(peers)) return instance.op->emitError("failed to collect peer compute instances for equivalence class slot"); @@ -1699,6 +3205,53 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns return success(); } +Value createReceiveConcatLoop(MaterializerState& state, + Operation* anchor, + Operation* insertionPoint, + RankedTensorType concatType, + RankedTensorType fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(!channelIds.empty() && "expected at least one receive"); + + Value lowerBound = createIndexConstant(state, anchor, 0); + Value upperBound = createIndexConstant(state, anchor, static_cast(channelIds.size())); + Value step = createIndexConstant(state, anchor, 1); + + state.rewriter.setInsertionPoint(insertionPoint); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); + auto loop = scf::ForOp::create(state.rewriter, loc, lowerBound, upperBound, step, ValueRange {init}); + + Block* body = loop.getBody(); + if (!body->empty()) + if (auto yield = dyn_cast(body->back())) + state.rewriter.eraseOp(yield); + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPointToEnd(body); + + Value index = loop.getInductionVar(); + Value acc = body->getArgument(1); + + Value channelId = createIndexedIndexValue(state, anchor, channelIds, index, loc); + Value sourceCoreId = createIndexedIndexValue(state, anchor, sourceCoreIds, index, loc); + Value targetCoreId = createIndexedIndexValue(state, anchor, targetCoreIds, index, loc); + + Value received = + SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId).getOutput(); + + Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); + Value next = createDim0InsertSlice(state, loc, received, acc, firstOffset); + scf::YieldOp::create(state.rewriter, loc, next); + + return loop.getResult(0); +} + void replaceHostUses(MaterializerState& state) { for (const auto& [oldValue, replacement] : state.hostReplacements) replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); @@ -1738,8 +3291,6 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch if (failed(materializeInstanceSlot(state, instance))) return failure(); - compactReceiveConcats(state); - replaceHostUses(state); if (failed(eraseOldComputeOps(state))) return failure(); diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 6b67cee..490ca59 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -2,6 +2,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -37,6 +38,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, OperationFolder& constantFolder, bool& hasFailure) { DenseMap>> materializedValues; + DominanceInfo dominance(coreOp); SmallVector ops; coreOp.getBody().front().walk([&](Operation* op) { if (!isa(op)) @@ -70,7 +72,7 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, auto& cachedByOffset = materializedValues[resolvedAddress->base]; auto& cachedByType = cachedByOffset[resolvedAddress->byteOffset]; auto cachedValue = cachedByType.find(originalType); - if (cachedValue != cachedByType.end()) { + if (cachedValue != cachedByType.end() && dominance.properlyDominates(cachedValue->second, op)) { operand.set(cachedValue->second); continue; }