diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 261dcf9..2856699 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -214,21 +214,29 @@ struct ProjectedBatchInputKeyInfo { static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } }; +struct ProjectedFragmentLayout { + RankedTensorType fragmentType; + SmallVector fragmentShape; + unsigned fragmentsPerLogicalSlot = 1; + unsigned payloadFragmentCount = 1; + SmallVector loopLowerBounds; + SmallVector loopSteps; + SmallVector loopTripCounts; +}; + struct ProjectedTransferDescriptor { ProjectedBatchInputKey inputKey; Operation* extractOp = nullptr; - RankedTensorType fragmentType; + ProjectedFragmentLayout layout; RankedTensorType payloadType; - unsigned sourceProjectedDim = 0; - unsigned fragmentsPerLane = 1; - SmallVector laneMajorProjectedOffsets; + SmallVector, 16> fragmentOffsets; + SmallVector, 4> fragmentOffsetsByDim; }; struct ProjectedExtractReplacement { Value payload; - RankedTensorType fragmentType; - unsigned fragmentsPerLane = 1; + ProjectedFragmentLayout layout; }; struct CloneIndexingContext { @@ -236,8 +244,30 @@ struct CloneIndexingContext { std::optional projectionSlotIndex; }; +struct StaticProjectedLoopInfo { + BlockArgument iv; + int64_t lowerBound = 0; + int64_t step = 1; + int64_t tripCount = 1; +}; + +struct AffineProjectedInputSliceMatch { + tensor::ExtractSliceOp extract; + RankedTensorType sourceType; + RankedTensorType fragmentType; + SmallVector fragmentShape; + SmallVector offsets; + SmallVector loops; +}; + struct MaterializerState; +FailureOr materializeProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + const ProjectedExtractReplacement& replacement, + std::optional projectionSlotIndex); + class AvailableValueStore { public: void record(ProducerKey key, ClassId classId, Value value) { exactValues[key][classId] = value; } @@ -276,6 +306,11 @@ struct MaterializerState { DenseMap, ProducerKeyInfo> producerDestClasses; DenseMap, ProducerKeyInfo> sameClassConsumers; + DenseMap projectedInputMatches; + DenseSet nonProjectedInputs; + DenseMap liveExternalUseCache; + DenseMap> batchOutputFragmentTypesCache; + DenseMap, llvm::DenseMapInfo> computeInstanceOutputsCache; DenseMap, ProducerKeyInfo> projectedTransfers; DenseMap> projectedExtractReplacements; AvailableValueStore availableValues; @@ -303,6 +338,16 @@ bool isInsideOldCompute(Operation* op, const DenseSet& oldComputeOps } bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps); +ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance); + +bool hasLiveExternalUseCached(MaterializerState& state, Value value) { + auto cached = state.liveExternalUseCache.find(value); + if (cached != state.liveExternalUseCache.end()) + return cached->second; + bool live = hasLiveExternalUse(value, state.oldComputeOps); + state.liveExternalUseCache[value] = live; + return live; +} std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extract) { if (extract.getMixedOffsets().empty()) @@ -728,8 +773,8 @@ LogicalResult collectHostOutputs(MaterializerState& state) { return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); MaterializedClass& materializedClass = state.classes[state.cpuToClass.lookup(cpuIt->second)]; - for (Value output : getComputeInstanceOutputValues(instance)) { - if (!hasLiveExternalUse(output, state.oldComputeOps) || !seenOutputs.insert(output).second) + for (Value output : getComputeInstanceOutputValuesCached(state, instance)) { + if (!hasLiveExternalUseCached(state, output) || !seenOutputs.insert(output).second) continue; materializedClass.hostOutputToResultIndex[output] = materializedClass.hostOutputs.size(); @@ -903,13 +948,14 @@ Value createDim0ExtractSlice( .getResult(); } -Value createSingleDimExtractSlice(MaterializerState& state, - Location loc, - Value source, - unsigned sliceDim, - OpFoldResult offset, - ArrayRef resultShape) { +Value createStaticExtractSlice(MaterializerState& state, + Location loc, + Value source, + ArrayRef sliceOffsets, + ArrayRef resultShape) { auto sourceType = cast(source.getType()); + assert(sliceOffsets.size() == static_cast(sourceType.getRank()) && "offset rank mismatch"); + assert(resultShape.size() == static_cast(sourceType.getRank()) && "result rank mismatch"); SmallVector offsets; SmallVector sizes; SmallVector strides; @@ -918,7 +964,7 @@ Value createSingleDimExtractSlice(MaterializerState& state, strides.reserve(sourceType.getRank()); for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { - offsets.push_back(dim == static_cast(sliceDim) ? offset : OpFoldResult(state.rewriter.getIndexAttr(0))); + offsets.push_back(sliceOffsets[dim]); sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim])); strides.push_back(state.rewriter.getIndexAttr(1)); } @@ -926,6 +972,32 @@ Value createSingleDimExtractSlice(MaterializerState& state, return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); } +Value createIndexedIndexValue(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc, + std::optional preferredPeriod = std::nullopt, + bool allowExhaustiveTiledSearch = true); + +SmallVector buildProjectedFragmentOffsets(MaterializerState& state, + Operation* anchor, + const ProjectedTransferDescriptor& descriptor, + Value flatFragmentIndex, + Location loc) { + SmallVector fragmentOffsets; + fragmentOffsets.reserve(descriptor.layout.fragmentShape.size()); + for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) + fragmentOffsets.push_back(createIndexedIndexValue(state, + anchor, + dimOffsets, + flatFragmentIndex, + loc, + static_cast(descriptor.layout.payloadFragmentCount), + /*allowExhaustiveTiledSearch=*/false)); + return fragmentOffsets; +} + Value createDim0InsertSlice( MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { auto fragmentType = cast(fragment.getType()); @@ -1241,45 +1313,51 @@ bool matchAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) return true; } +bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern, int64_t period) { + assert(!values.empty() && "expected at least one value"); + if (period < 2 || period > static_cast(values.size() / 2)) + return false; + + int64_t base = values.front(); + int64_t innerStep = values[1] - values[0]; + int64_t outerStep = values[period] - values[0]; + + for (auto [index, value] : llvm::enumerate(values)) { + int64_t i = static_cast(index); + int64_t expected = base + outerStep * (i / period) + innerStep * (i % period); + if (value != expected) + return false; + } + + pattern.base = base; + pattern.period = period; + pattern.innerStep = innerStep; + pattern.outerStep = outerStep; + pattern.isTiled = true; + return true; +} + bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { assert(!values.empty() && "expected at least one value"); - for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period) { - int64_t base = values.front(); - int64_t innerStep = values[1] - values[0]; - int64_t outerStep = values[period] - values[0]; - - bool matches = true; - for (auto [index, value] : llvm::enumerate(values)) { - int64_t i = static_cast(index); - int64_t expected = base + outerStep * (i / period) + innerStep * (i % period); - if (value != expected) { - matches = false; - break; - } - } - - if (!matches) - continue; - - pattern.base = base; - pattern.period = period; - pattern.innerStep = innerStep; - pattern.outerStep = outerStep; - pattern.isTiled = true; - return true; - } + for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period) + if (matchTiledAffineSequence(values, pattern, period)) + return true; return false; } -std::optional getIndexedIndexPattern(ArrayRef values) { +std::optional getIndexedIndexPattern(ArrayRef values, + std::optional preferredPeriod = std::nullopt, + bool allowExhaustiveTiledSearch = true) { assert(!values.empty() && "expected at least one value"); IndexedIndexPattern pattern; if (matchAffineSequence(values, pattern)) return pattern; - if (matchTiledAffineSequence(values, pattern)) + if (preferredPeriod && matchTiledAffineSequence(values, pattern, *preferredPeriod)) + return pattern; + if (allowExhaustiveTiledSearch && values.size() <= 256 && matchTiledAffineSequence(values, pattern)) return pattern; return std::nullopt; @@ -1302,14 +1380,20 @@ Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func); } -Value createIndexedIndexValue( - MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { +Value createIndexedIndexValue(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc, + std::optional preferredPeriod, + bool allowExhaustiveTiledSearch) { assert(!values.empty() && "expected at least one indexed value"); if (allEqual(values)) return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); - if (std::optional pattern = getIndexedIndexPattern(values)) + if (std::optional pattern = + getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) return createAffineIndexValue(state, *pattern, index, loc); Value table = createIndexTensorConstant(state, anchor, values); @@ -1325,7 +1409,7 @@ Value createIndexedIndexValue( for (int32_t value : values) widened.push_back(value); - return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc); + return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, std::nullopt, true); } Value createIndexedChannelId( @@ -1333,16 +1417,46 @@ Value createIndexedChannelId( return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc); } +Value createIndexedChannelId(MaterializerState& state, + Operation* anchor, + const MessageVector& messages, + Value index, + Location loc, + std::optional preferredPeriod) { + return createIndexedIndexValue( + state, anchor, ArrayRef(messages.channelIds), index, loc, preferredPeriod, true); +} + Value createIndexedSourceCoreId( MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { return createIndexedIndexValue(state, anchor, ArrayRef(messages.sourceCoreIds), index, loc); } +Value createIndexedSourceCoreId(MaterializerState& state, + Operation* anchor, + const MessageVector& messages, + Value index, + Location loc, + std::optional preferredPeriod) { + SmallVector widened(messages.sourceCoreIds.begin(), messages.sourceCoreIds.end()); + return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); +} + Value createIndexedTargetCoreId( MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { return createIndexedIndexValue(state, anchor, ArrayRef(messages.targetCoreIds), index, loc); } +Value createIndexedTargetCoreId(MaterializerState& state, + Operation* anchor, + const MessageVector& messages, + Value index, + Location loc, + std::optional preferredPeriod) { + SmallVector widened(messages.targetCoreIds.begin(), messages.targetCoreIds.end()); + return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); +} + Value createLaneIndexedIndexValue(MaterializerState& state, MaterializedClass& materializedClass, ArrayRef values, @@ -1476,12 +1590,73 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { }); } -static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesExpected) { - if (value == expected) { - usesExpected = true; +bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceType, RankedTensorType fragmentType) { + if (offsets.size() != static_cast(sourceType.getRank()) + || offsets.size() != static_cast(fragmentType.getRank())) + return false; + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + int64_t offset = offsets[dim]; + if (offset < 0) + return false; + + int64_t sourceDimSize = sourceType.getDimSize(dim); + int64_t fragmentDimSize = fragmentType.getDimSize(dim); + if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) + return false; + if (offset > sourceDimSize - fragmentDimSize) + return false; + } + + return true; +} + +static std::optional getStaticForTripCount(scf::ForOp loop) { + std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); + std::optional upperBound = matchConstantIndexValue(loop.getUpperBound()); + std::optional step = matchConstantIndexValue(loop.getStep()); + if (!lowerBound || !upperBound || !step || *step <= 0 || *upperBound < *lowerBound) + return std::nullopt; + + int64_t distance = *upperBound - *lowerBound; + return (distance + *step - 1) / *step; +} + +static SmallVector collectEnclosingStaticProjectedLoops(Operation* op) { + SmallVector loops; + SmallVector reversedLoops; + for (Operation* current = op->getParentOp(); current; current = current->getParentOp()) + if (auto loop = dyn_cast(current)) + reversedLoops.push_back(loop); + + for (scf::ForOp loop : llvm::reverse(reversedLoops)) { + std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); + std::optional step = matchConstantIndexValue(loop.getStep()); + std::optional tripCount = getStaticForTripCount(loop); + if (!lowerBound || !step || !tripCount) + return {}; + loops.push_back(StaticProjectedLoopInfo {.iv = cast(loop.getInductionVar()), + .lowerBound = *lowerBound, + .step = *step, + .tripCount = *tripCount}); + } + return loops; +} + +static bool +isProjectedOffsetValue(Value value, Value laneArg, ArrayRef loops, bool& usesDynamicBinding) { + if (value == laneArg) { + usesDynamicBinding = true; return true; } + for (const StaticProjectedLoopInfo& loop : loops) { + if (value == loop.iv) { + usesDynamicBinding = true; + return true; + } + } + if (matchPattern(value, m_Constant())) return true; @@ -1489,104 +1664,183 @@ static bool isLaneProjectedOffsetValue(Value value, Value expected, bool& usesEx if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) return false; - bool nestedUsesExpected = false; + bool nestedUsesDynamicBinding = false; for (Value operand : affineApply.getMapOperands()) { - bool operandUsesExpected = false; - if (!isLaneProjectedOffsetValue(operand, expected, operandUsesExpected)) + bool operandUsesDynamicBinding = false; + if (!isProjectedOffsetValue(operand, laneArg, loops, operandUsesDynamicBinding)) return false; - nestedUsesExpected = nestedUsesExpected || operandUsesExpected; + nestedUsesDynamicBinding = nestedUsesDynamicBinding || operandUsesDynamicBinding; } - usesExpected = usesExpected || nestedUsesExpected; - return nestedUsesExpected; + usesDynamicBinding = usesDynamicBinding || nestedUsesDynamicBinding; + return true; } -bool isValueOffset(OpFoldResult offset, Value expected) { - auto value = dyn_cast(offset); - if (!value) - return false; +static std::optional getConstantIndex(OpFoldResult value); - bool usesExpected = false; - return isLaneProjectedOffsetValue(value, expected, usesExpected) && usesExpected; -} - -bool isStaticIndexAttr(OpFoldResult value, int64_t expected) { - auto attr = dyn_cast(value); - if (!attr) - return false; - - auto intAttr = dyn_cast(attr); - return intAttr && intAttr.getInt() == expected; -} - -bool isStaticSliceInBounds(int64_t offset, - RankedTensorType sourceType, - RankedTensorType fragmentType, - unsigned sliceDim) { - if (offset < 0) - return false; - if (sliceDim >= static_cast(sourceType.getRank()) - || sliceDim >= static_cast(fragmentType.getRank())) - return false; - - int64_t sourceDimSize = sourceType.getDimSize(sliceDim); - int64_t fragmentDimSize = fragmentType.getDimSize(sliceDim); - if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) - return false; - - return offset <= sourceDimSize - fragmentDimSize; -} - -std::optional getLaneProjectedDim(ArrayRef offsets, Value laneArg) { - std::optional projectedDim; - for (auto [dim, offset] : llvm::enumerate(offsets)) { - if (!isValueOffset(offset, laneArg)) - continue; - - if (projectedDim) - return std::nullopt; - projectedDim = static_cast(dim); +static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef loopTripCounts) { + unsigned fragmentsPerLogicalSlot = 1; + for (int64_t tripCount : loopTripCounts) { + assert(tripCount > 0 && "projected loop trip counts must be positive"); + fragmentsPerLogicalSlot *= static_cast(tripCount); } - - return projectedDim; + return fragmentsPerLogicalSlot; } -static FailureOr evaluateProjectedOffsetForLane(OpFoldResult value, Value laneArg, uint32_t lane) { - if (std::optional constant = matchConstantIndexValue(value)) +LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) { + if (!layout.fragmentType || layout.fragmentShape.empty()) + return anchor->emitError("projected fragment layout is missing fragment type metadata"); + if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank())) + return anchor->emitError("projected fragment layout rank does not match fragment type"); + if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) + return anchor->emitError("projected fragment layout has an invalid fragment count"); + if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0) + return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots"); + return success(); +} + +FailureOr +getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) { + if (failed( + verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type"))) + return failure(); + return getPackedBatchTensorType(fragmentType, payloadFragmentCount); +} + +SmallVector, 4> +buildProjectedFragmentOffsetsByDim(ArrayRef> fragmentOffsets, size_t rank) { + SmallVector, 4> fragmentOffsetsByDim(rank); + for (ArrayRef offsets : fragmentOffsets) { + assert(offsets.size() == rank && "projected offset rank mismatch"); + for (size_t dim = 0; dim < rank; ++dim) + fragmentOffsetsByDim[dim].push_back(offsets[dim]); + } + return fragmentOffsetsByDim; +} + +LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) { + if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout))) + return failure(); + if (!descriptor.payloadType) + return anchor->emitError("projected transfer descriptor is missing payload type"); + if (descriptor.fragmentOffsets.empty()) + return anchor->emitError("projected transfer descriptor expected at least one fragment offset"); + if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size()) + return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); + for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) + if (dimOffsets.size() != descriptor.fragmentOffsets.size()) + return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); + for (ArrayRef offsets : descriptor.fragmentOffsets) + if (offsets.size() != descriptor.layout.fragmentShape.size()) + return anchor->emitError("projected transfer offset rank does not match fragment rank"); + return success(); +} + +LogicalResult verifyProjectedSendDescriptor(Operation* anchor, + const ProjectedTransferDescriptor& descriptor, + const MessageVector& messages) { + if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + return failure(); + if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size()) + return anchor->emitError("projected send descriptor metadata is inconsistent"); + return success(); +} + +LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) { + descriptor.fragmentOffsetsByDim = + buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size()); + + FailureOr payloadType = + getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount); + if (failed(payloadType)) + return failure(); + if (descriptor.payloadType && descriptor.payloadType != *payloadType) + return anchor->emitError("projected transfer descriptor payload type does not match projected layout"); + descriptor.payloadType = *payloadType; + + return verifyProjectedTransferDescriptor(anchor, descriptor); +} + +static FailureOr evaluateProjectedOffsetValue(OpFoldResult value, + Value laneArg, + uint32_t lane, + ArrayRef loops, + ArrayRef loopIterationIndices) { + if (std::optional constant = getConstantIndex(value)) return *constant; - Value current = cast(value); + Value current = dyn_cast(value); + if (!current) + return failure(); if (current == laneArg) return static_cast(lane); - if (auto affineApply = current.getDefiningOp()) - return evaluateAffineApply(affineApply, - [&](Value operand) { return evaluateProjectedOffsetForLane(operand, laneArg, lane); }); + for (auto [index, loop] : llvm::enumerate(loops)) { + if (current != loop.iv) + continue; + if (index >= loopIterationIndices.size()) + return failure(); + return loop.lowerBound + loopIterationIndices[index] * loop.step; + } + + if (auto affineApply = current.getDefiningOp()) { + return evaluateAffineApply(affineApply, [&](Value operand) { + return evaluateProjectedOffsetValue(operand, laneArg, lane, loops, loopIterationIndices); + }); + } return failure(); } -std::optional matchSimpleLaneProjectedInput(SpatComputeBatch batch, unsigned inputIndex) { +static std::optional getConstantIndex(OpFoldResult value) { + if (auto attr = dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return std::nullopt; + return intAttr.getInt(); + } + + Value operand = dyn_cast(value); + if (!operand) + return std::nullopt; + + if (auto constantIndex = operand.getDefiningOp()) + return constantIndex.value(); + + APInt apInt; + if (matchPattern(operand, m_ConstantInt(&apInt))) { + if (apInt.isNegative()) + return std::nullopt; + return static_cast(apInt.getSExtValue()); + } + + return std::nullopt; +} + +static std::optional matchAffineProjectedInputSlice(SpatComputeBatch batch, + unsigned inputIndex) { + const auto fail = [&](StringRef) -> std::optional { return std::nullopt; }; + std::optional inputArg = batch.getInputArgument(inputIndex); std::optional laneArg = batch.getLaneArgument(); if (!inputArg || !laneArg) - return std::nullopt; + return fail("missing-input-or-lane-arg"); if (!inputArg->hasOneUse()) - return std::nullopt; + return fail("input-arg-not-one-use"); Operation* user = *inputArg->getUsers().begin(); auto extract = dyn_cast(user); if (!extract || extract.getSource() != *inputArg) - return std::nullopt; + return fail("input-user-is-not-direct-extract-slice"); auto inputType = dyn_cast(inputArg->getType()); auto fragmentType = dyn_cast(extract.getResult().getType()); if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape()) - return std::nullopt; + return fail("non-static-ranked-input-or-fragment"); if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank()) - return std::nullopt; + return fail("rank-mismatch-or-rank-zero"); SmallVector offsets = extract.getMixedOffsets(); SmallVector sizes = extract.getMixedSizes(); @@ -1595,34 +1849,63 @@ std::optional matchSimpleLaneProjectedInput(SpatComputeB if (offsets.size() != static_cast(inputType.getRank()) || sizes.size() != static_cast(inputType.getRank()) || strides.size() != static_cast(inputType.getRank())) - return std::nullopt; + return fail("slice-rank-mismatch"); - std::optional projectedDim = getLaneProjectedDim(offsets, *laneArg); - if (!projectedDim) - return std::nullopt; + SmallVector loops = collectEnclosingStaticProjectedLoops(extract.getOperation()); + if (extract->getParentOfType() && loops.empty()) + return fail("unsupported-enclosing-loop"); - for (int64_t dim = 0; dim < inputType.getRank(); ++dim) { - if (dim == static_cast(*projectedDim)) { - if (!isStaticIndexAttr(sizes[dim], 1) || !isStaticIndexAttr(strides[dim], 1)) + bool hasDynamicProjection = false; + for (auto [dim, offset] : llvm::enumerate(offsets)) { + bool usesDynamicBinding = false; + if (auto value = dyn_cast(offset)) { + if (!isProjectedOffsetValue(value, *laneArg, loops, usesDynamicBinding)) return std::nullopt; - continue; } - - if (!isStaticIndexAttr(offsets[dim], 0)) + else if (!isa(offset)) return std::nullopt; - if (!isStaticIndexAttr(sizes[dim], inputType.getDimSize(dim))) + if (std::optional stride = getConstantIndex(strides[dim]); !stride || *stride != 1) return std::nullopt; - if (!isStaticIndexAttr(strides[dim], 1)) + std::optional size = getConstantIndex(sizes[dim]); + if (!size || *size != fragmentType.getDimSize(dim)) return std::nullopt; + hasDynamicProjection = hasDynamicProjection || usesDynamicBinding; } - for (int64_t dim = 0; dim < inputType.getRank(); ++dim) { - int64_t expectedDimSize = dim == *projectedDim ? 1 : inputType.getDimSize(dim); - if (fragmentType.getDimSize(dim) != expectedDimSize) + if (!hasDynamicProjection) + return fail("no-dynamic-projection"); + + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) + if (fragmentType.getDimSize(dim) <= 0 || fragmentType.getDimSize(dim) > inputType.getDimSize(dim)) return std::nullopt; + + AffineProjectedInputSliceMatch match; + match.extract = extract; + match.sourceType = inputType; + match.fragmentType = fragmentType; + match.offsets.assign(offsets.begin(), offsets.end()); + match.fragmentShape.assign(fragmentType.getShape().begin(), fragmentType.getShape().end()); + match.loops = std::move(loops); + return match; +} + +std::optional +getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex) { + ProjectedBatchInputKey key {batch.getOperation(), inputIndex}; + auto cached = state.projectedInputMatches.find(key); + if (cached != state.projectedInputMatches.end()) + return cached->second; + if (state.nonProjectedInputs.contains(key)) + return std::nullopt; + + std::optional match = matchAffineProjectedInputSlice(batch, inputIndex); + if (!match) { + state.nonProjectedInputs.insert(key); + return std::nullopt; } - return extract; + state.projectedInputMatches.insert({key, *match}); + return match; } LogicalResult collectProjectedTransfers(MaterializerState& state) { @@ -1630,13 +1913,60 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { ProjectedBatchInputKey inputKey; Operation* extractOp = nullptr; RankedTensorType fragmentType; - unsigned sourceProjectedDim = 0; - SmallVector, 8> offsetsByLane; + SmallVector fragmentShape; + SmallVector, 16>, 8> fragmentOffsetsByLane; + SmallVector loopLowerBounds; + SmallVector loopSteps; + SmallVector loopTripCounts; bool invalid = false; }; DenseMap, ProducerKeyInfo> pending; + const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor, + unsigned targetLane, + const AffineProjectedInputSliceMatch& match, + Value laneArg, + uint32_t lane) -> LogicalResult { + SmallVector loopIterationIndices; + loopIterationIndices.resize(match.loops.size(), 0); + + const auto appendOneFragment = [&]() -> LogicalResult { + SmallVector evaluatedOffsets; + evaluatedOffsets.reserve(match.offsets.size()); + for (OpFoldResult offset : match.offsets) { + FailureOr evaluated = + evaluateProjectedOffsetValue(offset, laneArg, lane, match.loops, loopIterationIndices); + if (failed(evaluated)) + return failure(); + evaluatedOffsets.push_back(*evaluated); + } + + if (!isStaticSliceInBounds(evaluatedOffsets, match.sourceType, match.fragmentType)) + return failure(); + + descriptor.fragmentOffsetsByLane[targetLane].push_back(std::move(evaluatedOffsets)); + return success(); + }; + + if (match.loops.empty()) + return appendOneFragment(); + + const auto recurse = [&](auto&& self, size_t loopIndex) -> LogicalResult { + if (loopIndex == match.loops.size()) + return appendOneFragment(); + + for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { + loopIterationIndices[loopIndex] = iteration; + if (failed(self(self, loopIndex + 1))) + return failure(); + } + return success(); + }; + + return recurse(recurse, 0); + }; + if (failed(forEachLogicalConsumerInMaterializationOrder( state, [&](CpuId cpu, @@ -1649,14 +1979,13 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { return success(); MaterializedClass& targetClass = state.classes[targetClassId]; - if (!targetClass.isBatch) - return success(); - - auto targetLaneIt = targetClass.cpuToLane.find(cpu); - if (targetLaneIt == targetClass.cpuToLane.end()) - return consumer.op->emitError("projected transfer collection could not recover target lane"); - - unsigned targetLane = targetLaneIt->second; + unsigned targetLane = 0; + if (targetClass.isBatch) { + auto targetLaneIt = targetClass.cpuToLane.find(cpu); + if (targetLaneIt == targetClass.cpuToLane.end()) + return consumer.op->emitError("projected transfer collection could not recover target lane"); + targetLane = targetLaneIt->second; + } for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); @@ -1673,48 +2002,57 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { if (sourceClassId == targetClassId) continue; - std::optional extract = - matchSimpleLaneProjectedInput(batch, static_cast(inputIndex)); - if (!extract) - continue; - - auto inputType = cast(extract->getSource().getType()); - auto fragmentType = cast((*extract).getResult().getType()); - SmallVector offsets = extract->getMixedOffsets(); - std::optional sourceProjectedDim = getLaneProjectedDim(offsets, *batch.getLaneArgument()); - if (!sourceProjectedDim) + std::optional match = + getProjectedInputSliceMatch(state, batch, static_cast(inputIndex)); + if (!match) continue; PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; - if (descriptor.offsetsByLane.empty()) { + if (descriptor.fragmentOffsetsByLane.empty()) { descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; - descriptor.extractOp = extract->getOperation(); - descriptor.fragmentType = fragmentType; - descriptor.sourceProjectedDim = *sourceProjectedDim; - descriptor.offsetsByLane.resize(targetClass.cpus.size()); + descriptor.extractOp = match->extract.getOperation(); + descriptor.fragmentType = match->fragmentType; + descriptor.fragmentShape = match->fragmentShape; + descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1); + descriptor.loopLowerBounds.reserve(match->loops.size()); + descriptor.loopSteps.reserve(match->loops.size()); + descriptor.loopTripCounts.reserve(match->loops.size()); + for (const StaticProjectedLoopInfo& loop : match->loops) { + descriptor.loopLowerBounds.push_back(loop.lowerBound); + descriptor.loopSteps.push_back(loop.step); + descriptor.loopTripCounts.push_back(loop.tripCount); + } } ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; - if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != extract->getOperation() - || descriptor.fragmentType != fragmentType || descriptor.sourceProjectedDim != *sourceProjectedDim) { + if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation() + || descriptor.fragmentType != match->fragmentType || descriptor.fragmentShape != match->fragmentShape + || descriptor.loopLowerBounds.size() != match->loops.size()) { + descriptor.invalid = true; + continue; + } + for (auto [index, loop] : llvm::enumerate(match->loops)) { + if (descriptor.loopLowerBounds[index] != loop.lowerBound || descriptor.loopSteps[index] != loop.step + || descriptor.loopTripCounts[index] != loop.tripCount) { + descriptor.invalid = true; + break; + } + } + if (descriptor.invalid) + continue; + + if (targetLane >= descriptor.fragmentOffsetsByLane.size()) { descriptor.invalid = true; continue; } - if (targetLane >= descriptor.offsetsByLane.size()) { - descriptor.invalid = true; - continue; - } - - FailureOr offset = evaluateProjectedOffsetForLane( - offsets[*sourceProjectedDim], *batch.getLaneArgument(), logicalConsumer.laneStart); - if (failed(offset) || !isStaticSliceInBounds(*offset, inputType, fragmentType, *sourceProjectedDim)) { + if (failed(appendEvaluatedFragments( + descriptor, targetLane, *match, *batch.getLaneArgument(), logicalConsumer.laneStart))) { descriptor.invalid = true; continue; } (void) logicalSlot; - descriptor.offsetsByLane[targetLane].push_back(*offset); } return success(); @@ -1729,38 +2067,52 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { if (pendingDescriptor.invalid) continue; - if (pendingDescriptor.offsetsByLane.empty()) + if (pendingDescriptor.fragmentOffsetsByLane.empty()) continue; - unsigned fragmentsPerLane = pendingDescriptor.offsetsByLane.front().size(); - if (fragmentsPerLane == 0) - continue; - - bool uniform = true; - for (ArrayRef laneOffsets : pendingDescriptor.offsetsByLane) { - if (laneOffsets.size() != fragmentsPerLane) { - uniform = false; - break; - } - } - if (!uniform) - continue; - - SmallVector payloadShape(pendingDescriptor.fragmentType.getShape()); - payloadShape[0] *= static_cast(fragmentsPerLane); - RankedTensorType payloadType = RankedTensorType::get( - payloadShape, pendingDescriptor.fragmentType.getElementType(), pendingDescriptor.fragmentType.getEncoding()); - + MaterializedClass& targetClass = state.classes[targetClassId]; ProjectedTransferDescriptor descriptor; descriptor.inputKey = pendingDescriptor.inputKey; descriptor.extractOp = pendingDescriptor.extractOp; - descriptor.fragmentType = pendingDescriptor.fragmentType; - descriptor.payloadType = payloadType; - descriptor.sourceProjectedDim = pendingDescriptor.sourceProjectedDim; - descriptor.fragmentsPerLane = fragmentsPerLane; - descriptor.laneMajorProjectedOffsets.reserve(pendingDescriptor.offsetsByLane.size() * fragmentsPerLane); - for (ArrayRef laneOffsets : pendingDescriptor.offsetsByLane) - llvm::append_range(descriptor.laneMajorProjectedOffsets, laneOffsets); + descriptor.layout.fragmentType = pendingDescriptor.fragmentType; + descriptor.layout.fragmentShape = pendingDescriptor.fragmentShape; + descriptor.layout.loopLowerBounds = pendingDescriptor.loopLowerBounds; + descriptor.layout.loopSteps = pendingDescriptor.loopSteps; + descriptor.layout.loopTripCounts = pendingDescriptor.loopTripCounts; + descriptor.layout.fragmentsPerLogicalSlot = getProjectedFragmentsPerLogicalSlot(descriptor.layout.loopTripCounts); + if (targetClass.isBatch) { + unsigned payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); + if (payloadFragmentCount == 0) + continue; + + bool uniform = true; + for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) { + if (laneFragments.size() != payloadFragmentCount) { + uniform = false; + break; + } + } + if (!uniform) + continue; + + descriptor.layout.payloadFragmentCount = payloadFragmentCount; + descriptor.fragmentOffsets.reserve(pendingDescriptor.fragmentOffsetsByLane.size() * payloadFragmentCount); + for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) + llvm::append_range(descriptor.fragmentOffsets, laneFragments); + } + else { + if (pendingDescriptor.fragmentOffsetsByLane.size() != 1) + return targetClass.op->emitError("scalar projected transfer descriptor expected one local offset stream"); + if (pendingDescriptor.fragmentOffsetsByLane.front().empty()) + continue; + + descriptor.layout.payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); + llvm::append_range(descriptor.fragmentOffsets, pendingDescriptor.fragmentOffsetsByLane.front()); + if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) + return targetClass.op->emitError("scalar projected transfer offset count does not match the local run"); + } + if (failed(finalizeProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); state.projectedTransfers[producer][targetClassId] = std::move(descriptor); } @@ -1769,6 +2121,66 @@ LogicalResult collectProjectedTransfers(MaterializerState& state) { return success(); } +static std::optional +collectScalarTargetProjectedDescriptor(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef keys, + bool requirePackedRunOffsetCountMatch) { + assert(!targetClass.isBatch && "scalar target projected descriptor helper expects a scalar class"); + + std::optional combined; + for (ProducerKey key : keys) { + auto producerIt = state.projectedTransfers.find(key); + if (producerIt == state.projectedTransfers.end()) + return std::nullopt; + + auto descriptorIt = producerIt->second.find(targetClass.id); + if (descriptorIt == producerIt->second.end()) + return std::nullopt; + + const ProjectedTransferDescriptor& descriptor = descriptorIt->second; + if (descriptor.fragmentOffsets.empty()) + return std::nullopt; + if (descriptor.layout.payloadFragmentCount == 0 || descriptor.layout.fragmentsPerLogicalSlot == 0) + return std::nullopt; + if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) + return std::nullopt; + if (descriptor.layout.payloadFragmentCount % descriptor.layout.fragmentsPerLogicalSlot != 0) + return std::nullopt; + + if (!combined) { + combined = descriptor; + continue; + } + + if (!(combined->inputKey == descriptor.inputKey) || combined->extractOp != descriptor.extractOp + || combined->layout.fragmentType != descriptor.layout.fragmentType + || combined->layout.fragmentShape != descriptor.layout.fragmentShape + || combined->layout.loopLowerBounds != descriptor.layout.loopLowerBounds + || combined->layout.loopSteps != descriptor.layout.loopSteps + || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts + || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot) + return std::nullopt; + + combined->layout.payloadFragmentCount += descriptor.layout.payloadFragmentCount; + llvm::append_range(combined->fragmentOffsets, descriptor.fragmentOffsets); + } + + if (!combined) + return std::nullopt; + + if (combined->fragmentOffsets.size() != combined->layout.payloadFragmentCount) + return std::nullopt; + + if (requirePackedRunOffsetCountMatch) { + if (combined->layout.payloadFragmentCount != keys.size() * combined->layout.fragmentsPerLogicalSlot) + return std::nullopt; + } + if (failed(finalizeProjectedTransferDescriptor(targetClass.op, *combined))) + return std::nullopt; + return combined; +} + bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { if (keys.empty()) return true; @@ -1853,16 +2265,19 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, Operation* anchor, Value fullPayload, const ProjectedTransferDescriptor& descriptor, - Value laneIndex, + Value messageIndex, Location loc) { - assert(descriptor.fragmentsPerLane > 1 && "use direct fragment path for single-fragment projection"); + if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + return failure(); + if (descriptor.layout.payloadFragmentCount == 1) + return anchor->emitError("projected packed payload builder expects a packed payload"); Value init = tensor::EmptyOp::create( state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) .getResult(); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, anchor, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.layout.payloadFragmentCount); Value step = getOrCreateIndexConstant(state.constantFolder, anchor, 1); auto loop = buildNormalizedScfFor( @@ -1874,16 +2289,18 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, ValueRange {init}, [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { Value acc = iterArgs.front(); - Value fragmentsPerLane = getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.fragmentsPerLane); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, laneIndex, fragmentsPerLane).getResult(); + Value payloadFragmentCount = + getOrCreateIndexConstant(state.constantFolder, anchor, descriptor.layout.payloadFragmentCount); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, messageIndex, payloadFragmentCount).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); - Value sourceOffset = createIndexedIndexValue(state, anchor, descriptor.laneMajorProjectedOffsets, flatIndex, loc); - Value fragment = createSingleDimExtractSlice( - state, loc, fullPayload, descriptor.sourceProjectedDim, sourceOffset, descriptor.fragmentType.getShape()); + SmallVector fragmentOffsets = + buildProjectedFragmentOffsets(state, anchor, descriptor, flatIndex, loc); + Value fragment = + createStaticExtractSlice(state, loc, fullPayload, fragmentOffsets, descriptor.layout.fragmentShape); Value packedOffset = - scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.fragmentType.getDimSize(0), loc); + scaleIndexByDim0Size(state, anchor, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); Value next = createDim0InsertSlice(state, loc, fragment, acc, packedOffset); yielded.push_back(next); return success(); @@ -1893,6 +2310,24 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, return loop->results.front(); } +FailureOr buildProjectedPayloadForMessage(MaterializerState& state, + Operation* anchor, + Value fullPayload, + const ProjectedTransferDescriptor& descriptor, + Value messageIndex, + Location loc) { + if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + return failure(); + + if (descriptor.layout.payloadFragmentCount == 1) { + SmallVector fragmentOffsets = + buildProjectedFragmentOffsets(state, anchor, descriptor, messageIndex, loc); + return createStaticExtractSlice(state, loc, fullPayload, fragmentOffsets, descriptor.layout.fragmentShape); + } + + return buildProjectedPackedPayload(state, anchor, fullPayload, descriptor, messageIndex, loc); +} + LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -1901,8 +2336,8 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - assert(messages.size() * descriptor.fragmentsPerLane == descriptor.laneMajorProjectedOffsets.size() - && "projected send lane count mismatch"); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, descriptor, messages))) + return failure(); state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); @@ -1910,22 +2345,12 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); - Value laneIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value sendPayload; - if (descriptor.fragmentsPerLane == 1) { - Value offset = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, descriptor.laneMajorProjectedOffsets.front()); - sendPayload = createSingleDimExtractSlice( - state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); - } - else { - auto packedPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, laneIndex, loc); - if (failed(packedPayload)) - return failure(); - sendPayload = *packedPayload; - } - - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); + Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass.op, payload, descriptor, messageIndex, loc); + if (failed(sendPayload)) + return failure(); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); return success(); } @@ -1945,21 +2370,11 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); - - Value sendPayload; - if (descriptor.fragmentsPerLane == 1) { - Value offset = createIndexedIndexValue(state, sourceClass.op, descriptor.laneMajorProjectedOffsets, index, loc); - sendPayload = createSingleDimExtractSlice( - state, loc, payload, descriptor.sourceProjectedDim, offset, descriptor.fragmentType.getShape()); - } - else { - auto packedPayload = buildProjectedPackedPayload(state, sourceClass.op, payload, descriptor, index, loc); - if (failed(packedPayload)) - return failure(); - sendPayload = *packedPayload; - } - - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, sendPayload); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass.op, payload, descriptor, index, loc); + if (failed(sendPayload)) + return failure(); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); return success(); }); if (failed(projectedSendLoop)) @@ -2112,34 +2527,28 @@ struct ScalarSourceReceivePlan { MessageVector messages; Type receiveType; Operation* projectedExtractOp = nullptr; - RankedTensorType projectedFragmentType; - unsigned projectedFragmentsPerLane = 1; + ProjectedFragmentLayout projectedLayout; }; -struct ProjectedTransferCompatibilityKey { - RankedTensorType fragmentType; - RankedTensorType payloadType; - unsigned sourceProjectedDim = 0; - unsigned fragmentsPerLane = 1; - - bool operator==(const ProjectedTransferCompatibilityKey& other) const { - return fragmentType == other.fragmentType && payloadType == other.payloadType - && sourceProjectedDim == other.sourceProjectedDim && fragmentsPerLane == other.fragmentsPerLane; - } -}; - -struct ScalarSourceSendGroup { +struct ProjectedScalarSendGroup { MessageVector messages; - std::optional projectedKey; - SmallVector projectedOffsets; + ProjectedTransferDescriptor descriptor; }; struct ScalarSourceFanoutPlan { SmallVector receivePlans; - std::optional ordinarySendGroup; - SmallVector projectedSendGroups; + std::optional ordinaryMessages; + SmallVector projectedSendGroups; }; +bool hasSameProjectedSendCompatibility(const ProjectedTransferDescriptor& lhs, const ProjectedTransferDescriptor& rhs) { + return lhs.layout.fragmentType == rhs.layout.fragmentType && lhs.layout.fragmentShape == rhs.layout.fragmentShape + && lhs.layout.fragmentsPerLogicalSlot == rhs.layout.fragmentsPerLogicalSlot + && lhs.layout.payloadFragmentCount == rhs.layout.payloadFragmentCount + && lhs.layout.loopLowerBounds == rhs.layout.loopLowerBounds && lhs.layout.loopSteps == rhs.layout.loopSteps + && lhs.layout.loopTripCounts == rhs.layout.loopTripCounts && lhs.payloadType == rhs.payloadType; +} + SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { SmallVector destinations; @@ -2166,28 +2575,41 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& ScalarSourceFanoutPlan fanoutPlan; fanoutPlan.receivePlans.reserve(destinationClasses.size()); - const auto getProjectedDescriptor = [&](ClassId destinationClass) -> const ProjectedTransferDescriptor* { - if (keys.size() != 1) - return nullptr; - + const auto getProjectedDescriptor = + [&](ClassId destinationClass) -> FailureOr> { MaterializedClass& targetClass = state.classes[destinationClass]; - if (!targetClass.isBatch) - return nullptr; + if (!targetClass.isBatch) { + bool hasAnyProjectedDescriptor = llvm::any_of(keys, [&](ProducerKey key) { + auto producerIt = state.projectedTransfers.find(key); + return producerIt != state.projectedTransfers.end() && producerIt->second.count(destinationClass) != 0; + }); + + std::optional descriptor = collectScalarTargetProjectedDescriptor( + state, targetClass, keys, /*requirePackedRunOffsetCountMatch=*/keys.size() > 1); + if (hasAnyProjectedDescriptor && !descriptor) + return targetClass.op->emitError("incomplete scalar projected transfer descriptor for local run"); + return descriptor; + } + + if (keys.size() != 1) + return std::optional {}; auto producerIt = state.projectedTransfers.find(keys.front()); if (producerIt == state.projectedTransfers.end()) - return nullptr; + return std::optional {}; auto descriptorIt = producerIt->second.find(destinationClass); if (descriptorIt == producerIt->second.end()) - return nullptr; + return std::optional {}; const ProjectedTransferDescriptor& descriptor = descriptorIt->second; - if (descriptor.laneMajorProjectedOffsets.size() - != targetClass.cpus.size() * static_cast(descriptor.fragmentsPerLane)) - return nullptr; + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + if (descriptor.fragmentOffsets.size() + != targetClass.cpus.size() * static_cast(descriptor.layout.payloadFragmentCount)) + return targetClass.op->emitError("inconsistent batch projected transfer descriptor"); - return &descriptor; + return std::optional {descriptor}; }; for (ClassId destinationClass : destinationClasses) { @@ -2220,41 +2642,52 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& return failure(); } - if (const ProjectedTransferDescriptor* descriptor = getProjectedDescriptor(destinationClass)) { - receivePlan.receiveType = descriptor->payloadType; - receivePlan.projectedExtractOp = descriptor->extractOp; - receivePlan.projectedFragmentType = descriptor->fragmentType; - receivePlan.projectedFragmentsPerLane = descriptor->fragmentsPerLane; + FailureOr> descriptor = getProjectedDescriptor(destinationClass); + if (failed(descriptor)) + return failure(); - ProjectedTransferCompatibilityKey key {descriptor->fragmentType, - descriptor->payloadType, - descriptor->sourceProjectedDim, - descriptor->fragmentsPerLane}; + if (*descriptor) { + const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; - auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ScalarSourceSendGroup& group) { - return group.projectedKey && *group.projectedKey == key; + if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) + return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type"); + + receivePlan.receiveType = projectedDescriptor.payloadType; + receivePlan.projectedExtractOp = projectedDescriptor.extractOp; + receivePlan.projectedLayout = projectedDescriptor.layout; + + auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { + return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); }); if (groupIt == fanoutPlan.projectedSendGroups.end()) { - ScalarSourceSendGroup group; - group.projectedKey = key; + ProjectedScalarSendGroup group; + group.descriptor.layout = projectedDescriptor.layout; + group.descriptor.payloadType = projectedDescriptor.payloadType; fanoutPlan.projectedSendGroups.push_back(std::move(group)); groupIt = std::prev(fanoutPlan.projectedSendGroups.end()); } groupIt->messages.append( receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - llvm::append_range(groupIt->projectedOffsets, descriptor->laneMajorProjectedOffsets); + llvm::append_range(groupIt->descriptor.fragmentOffsets, projectedDescriptor.fragmentOffsets); } else { - if (!fanoutPlan.ordinarySendGroup) - fanoutPlan.ordinarySendGroup = ScalarSourceSendGroup {}; - fanoutPlan.ordinarySendGroup->messages.append( + if (!fanoutPlan.ordinaryMessages) + fanoutPlan.ordinaryMessages = MessageVector {}; + fanoutPlan.ordinaryMessages->append( receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); } fanoutPlan.receivePlans.push_back(std::move(receivePlan)); } + for (ProjectedScalarSendGroup& group : fanoutPlan.projectedSendGroups) { + if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, group.descriptor))) + return failure(); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, group.descriptor, group.messages))) + return failure(); + } + return fanoutPlan; } @@ -2263,23 +2696,12 @@ LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, Value payload, const ScalarSourceFanoutPlan& plan, Location loc) { - if (plan.ordinarySendGroup && failed(appendSend(state, sourceClass, payload, plan.ordinarySendGroup->messages, loc))) + if (plan.ordinaryMessages && failed(appendSend(state, sourceClass, payload, *plan.ordinaryMessages, loc))) return failure(); - for (const ScalarSourceSendGroup& group : plan.projectedSendGroups) { - if (!group.projectedKey) - return sourceClass.op->emitError("projected scalar send group is missing a compatibility key"); - - ProjectedTransferDescriptor descriptor; - descriptor.fragmentType = group.projectedKey->fragmentType; - descriptor.payloadType = group.projectedKey->payloadType; - descriptor.sourceProjectedDim = group.projectedKey->sourceProjectedDim; - descriptor.fragmentsPerLane = group.projectedKey->fragmentsPerLane; - descriptor.laneMajorProjectedOffsets = group.projectedOffsets; - - if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, descriptor, group.messages, loc))) + for (const ProjectedScalarSendGroup& group : plan.projectedSendGroups) + if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, group.descriptor, group.messages, loc))) return failure(); - } return success(); } @@ -2305,7 +2727,7 @@ LogicalResult emitScalarSourceCommunication( if (plan.projectedExtractOp) { state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = - ProjectedExtractReplacement {received, plan.projectedFragmentType, plan.projectedFragmentsPerLane}; + ProjectedExtractReplacement {received, plan.projectedLayout}; continue; } @@ -2438,7 +2860,7 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val LogicalResult emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) { - if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) + if (!hasLiveExternalUseCached(state, originalOutput)) return success(); return setHostOutputValue(state, sourceClass, originalOutput, payload); @@ -3091,11 +3513,11 @@ bool hasProjectedInputReplacement(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, ClassId classId) { - std::optional extract = matchSimpleLaneProjectedInput(batch, inputIndex); - if (!extract) + std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); + if (!match) return false; - auto replacementIt = state.projectedExtractReplacements.find(extract->getOperation()); + auto replacementIt = state.projectedExtractReplacements.find(match->extract.getOperation()); if (replacementIt == state.projectedExtractReplacements.end()) return false; @@ -3222,6 +3644,20 @@ SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) { return types; } +SmallVector& getBatchOutputFragmentTypesCached(MaterializerState& state, SpatComputeBatch batch) { + auto [it, inserted] = state.batchOutputFragmentTypesCache.try_emplace(batch.getOperation(), SmallVector {}); + if (inserted) + it->second = collectBatchOutputFragmentTypes(batch); + return it->second; +} + +ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance) { + auto [it, inserted] = state.computeInstanceOutputsCache.try_emplace(instance, SmallVector {}); + if (inserted) + it->second = getComputeInstanceOutputValues(instance); + return it->second; +} + std::optional lookupProjectedExtractReplacement(MaterializerState& state, MaterializedClass& targetClass, tensor::ExtractSliceOp extract) { @@ -3236,21 +3672,166 @@ std::optional lookupProjectedExtractReplacement(Mat return classIt->second; } +LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + Operation& originalOp, + Operation& clonedOp, + CloneIndexingContext indexing) { + if (auto originalExtract = dyn_cast(&originalOp)) { + if (std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, originalExtract)) { + auto clonedExtract = dyn_cast(&clonedOp); + if (!clonedExtract) + return targetClass.op->emitError("projected replacement lost extract structure during cloning"); + + state.rewriter.setInsertionPoint(clonedExtract); + FailureOr projected = materializeProjectedExtractReplacement( + state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex); + if (failed(projected)) + return failure(); + + clonedExtract.getResult().replaceAllUsesWith(*projected); + state.rewriter.eraseOp(clonedExtract); + return success(); + } + } + + if (originalOp.getNumRegions() != clonedOp.getNumRegions()) + return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned regions"); + + for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { + if (std::distance(originalRegion.begin(), originalRegion.end()) + != std::distance(clonedRegion.begin(), clonedRegion.end())) + return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned blocks"); + + for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { + auto originalIt = originalBlock.begin(); + auto clonedIt = clonedBlock.begin(); + while (originalIt != originalBlock.end() && clonedIt != clonedBlock.end()) { + Operation& originalNestedOp = *originalIt++; + Operation* currentClonedOp = &*clonedIt++; + if (failed(applyProjectedExtractReplacementsInClonedOp( + state, targetClass, originalNestedOp, *currentClonedOp, indexing))) + return failure(); + } + if (originalIt != originalBlock.end() || clonedIt != clonedBlock.end()) + return targetClass.op->emitError("projected replacement traversal found mismatched cloned operations"); + } + } + + return success(); +} + +LogicalResult cloneComputeTemplateBody(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper, + CloneIndexingContext indexing) { + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); + for (Operation& op : sourceBlock.without_terminator()) { + if (auto extract = dyn_cast(&op)) { + if (std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, extract)) { + FailureOr projected = materializeProjectedExtractReplacement( + state, targetClass, extract, *replacement, indexing.projectionSlotIndex); + if (failed(projected)) + return failure(); + + mapper.map(extract.getResult(), *projected); + continue; + } + } + + Operation* cloned = state.rewriter.clone(op, mapper); + if (op.getNumRegions() != 0 + && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing))) + return failure(); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(oldResult, newResult); + } + + return success(); +} + FailureOr materializeProjectedExtractReplacement(MaterializerState& state, MaterializedClass& targetClass, tensor::ExtractSliceOp extract, const ProjectedExtractReplacement& replacement, std::optional projectionSlotIndex) { - if (replacement.fragmentsPerLane == 1) + if (failed(verifyProjectedFragmentLayout(targetClass.op, replacement.layout))) + return failure(); + if (replacement.layout.payloadFragmentCount == 1) return replacement.payload; - if (!projectionSlotIndex) - return targetClass.op->emitError("packed projected extract replacement requires a projection slot index"); + if (replacement.layout.payloadFragmentCount < replacement.layout.fragmentsPerLogicalSlot) + return targetClass.op->emitError("projected replacement payload is smaller than one logical slot"); + + Value intraSlotFragmentIndex = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + const auto linearizeProjectedLoopIndices = [&]() -> FailureOr { + if (replacement.layout.loopTripCounts.empty()) + return intraSlotFragmentIndex; + + SmallVector surroundingLoops; + for (Operation* current = extract->getParentOp(); current; current = current->getParentOp()) { + if (auto loop = dyn_cast(current)) + surroundingLoops.push_back(loop); + if (current == targetClass.op) + break; + } + std::reverse(surroundingLoops.begin(), surroundingLoops.end()); + + if (surroundingLoops.size() != replacement.layout.loopTripCounts.size()) + return targetClass.op->emitError("projected replacement loop structure does not match the collected descriptor"); + + Value linearizedIndex = intraSlotFragmentIndex; + for (auto [index, loop] : llvm::enumerate(surroundingLoops)) { + Value iv = loop.getInductionVar(); + Value lowerBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]); + Value tripCount = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]); + + Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult(); + if (replacement.layout.loopSteps[index] != 1) + normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult(); + linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult(); + linearizedIndex = + arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult(); + } + return linearizedIndex; + }; + + FailureOr linearizedIndex = linearizeProjectedLoopIndices(); + if (failed(linearizedIndex)) + return failure(); + intraSlotFragmentIndex = *linearizedIndex; + + const auto computeProjectedPayloadFragmentIndex = [&]() -> FailureOr { + if (replacement.layout.payloadFragmentCount == replacement.layout.fragmentsPerLogicalSlot) { + if (replacement.layout.loopTripCounts.empty() && replacement.layout.fragmentsPerLogicalSlot != 1) + return targetClass.op->emitError("projected replacement is missing loop metadata for packed logical slot"); + return intraSlotFragmentIndex; + } + + if (!projectionSlotIndex) + return targetClass.op->emitError("packed projected extract replacement requires a fragment slot index"); + + Value fragmentsPerLogicalSlot = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot); + Value base = arith::MulIOp::create(state.rewriter, extract.getLoc(), *projectionSlotIndex, fragmentsPerLogicalSlot) + .getResult(); + return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult(); + }; + + FailureOr packedFragmentIndex = computeProjectedPayloadFragmentIndex(); + if (failed(packedFragmentIndex)) + return failure(); Value packedOffset = scaleIndexByDim0Size( - state, targetClass.op, *projectionSlotIndex, replacement.fragmentType.getDimSize(0), extract.getLoc()); + state, targetClass.op, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc()); return createDim0ExtractSlice( - state, extract.getLoc(), replacement.payload, packedOffset, replacement.fragmentType.getDimSize(0)); + state, extract.getLoc(), replacement.payload, packedOffset, replacement.layout.fragmentType.getDimSize(0)); } FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, @@ -3264,17 +3845,20 @@ FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, return failure(); Value flatIndex = createBatchRunFlatIndex(state, targetClass, runSlotIndex, loc); - Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, flatIndex, loc); + std::optional preferredPeriod = static_cast(targetClass.cpus.size()); + Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) .getOutput(); } -FailureOr> -cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, ArrayRef peers) { +FailureOr> cloneInstanceBody(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef peers, + CloneIndexingContext indexing) { assert(!peers.empty() && "expected at least one peer instance"); const ComputeInstance& instance = peers.front(); Operation* sourceOp = instance.op; @@ -3301,32 +3885,15 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); mapWeights(state, targetClass, instance, mapper); - if (failed(mapInputs(state, targetClass, instance, mapper, {}))) + if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) return failure(); state.rewriter.restoreInsertionPoint(cloneInsertionPoint); - - Block& sourceBlock = getComputeInstanceTemplateBlock(instance); - for (Operation& op : sourceBlock.without_terminator()) { - if (auto extract = dyn_cast(&op)) { - if (std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, extract)) { - FailureOr projected = - materializeProjectedExtractReplacement(state, targetClass, extract, *replacement, std::nullopt); - if (failed(projected)) - return failure(); - - mapper.map(extract.getResult(), *projected); - continue; - } - } - - Operation* cloned = state.rewriter.clone(op, mapper); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(oldResult, newResult); - } + if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) + return failure(); if (auto compute = dyn_cast(sourceOp)) { + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); auto yield = dyn_cast_or_null(sourceBlock.getTerminator()); if (!yield) { compute.emitOpError("expected spat.yield terminator while materializing compute"); @@ -3385,7 +3952,7 @@ SmallVector groupBatchRunOutputsByDestination(Materia assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); SmallVector groups; - SmallVector outputs = getComputeInstanceOutputValues(run.front().peers.front()); + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); for (auto [resultIndex, output] : llvm::enumerate(outputs)) { SmallVector destinations = collectDestinationClassesForRun(state, run, resultIndex); @@ -3535,7 +4102,7 @@ LogicalResult emitPackedRunFanout(MaterializerState& state, if (plan.projectedExtractOp) { state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = - ProjectedExtractReplacement {received, plan.projectedFragmentType, plan.projectedFragmentsPerLane}; + ProjectedExtractReplacement {received, plan.projectedLayout}; continue; } @@ -3570,26 +4137,8 @@ FailureOr> cloneBatchBodyForLane(MaterializerState& state, return failure(); state.rewriter.restoreInsertionPoint(cloneInsertionPoint); - - Block& sourceBlock = getComputeInstanceTemplateBlock(instance); - for (Operation& op : sourceBlock.without_terminator()) { - if (auto extract = dyn_cast(&op)) { - if (std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, extract)) { - FailureOr projected = materializeProjectedExtractReplacement( - state, targetClass, extract, *replacement, indexing.projectionSlotIndex); - if (failed(projected)) - return failure(); - - mapper.map(extract.getResult(), *projected); - continue; - } - } - - Operation* cloned = state.rewriter.clone(op, mapper); - for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) - mapper.map(oldResult, newResult); - } + if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) + return failure(); SmallVector allOutputs = collectMappedBatchOutputs(batch, mapper); if (allOutputs.empty() && !resultIndices.empty()) @@ -3630,7 +4179,7 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto sourceBatch = cast(sourceOp); - SmallVector fragmentTypes = collectBatchOutputFragmentTypes(sourceBatch); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); SmallVector initValues; for (size_t resultIndex : group.resultIndices) { @@ -3681,7 +4230,7 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta run.front().peers.front(), sourceLane, group.resultIndices, - CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); + CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); if (failed(produced)) return failure(); @@ -3769,10 +4318,11 @@ SmallVector getMaterializationRunOutputKeys(ArrayRef getFirstMaterializationRunOriginalOutputs(ArrayRef run) { +ArrayRef getFirstMaterializationRunOriginalOutputs(MaterializerState& state, + ArrayRef run) { assert(!run.empty() && "expected non-empty materialization run"); assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); - return getComputeInstanceOutputValues(run.front().peers.front()); + return getComputeInstanceOutputValuesCached(state, run.front().peers.front()); } Operation* getMaterializationRunSourceOp(ArrayRef run) { @@ -3790,11 +4340,11 @@ bool hasMaterializationRunResultLiveExternalUse(MaterializerState& state, size_t resultIndex) { for (const MaterializationRunSlot& slot : run) { for (const ComputeInstance& peer : slot.peers) { - SmallVector outputs = getComputeInstanceOutputValues(peer); + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, peer); if (resultIndex >= outputs.size()) return true; - if (hasLiveExternalUse(outputs[resultIndex], state.oldComputeOps)) + if (hasLiveExternalUseCached(state, outputs[resultIndex])) return true; } } @@ -3847,10 +4397,10 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, markMaterializationRunSlots(state, targetClass.id, startSlot, run); SmallVector groups = groupBatchRunOutputsByDestination(state, run); - SmallVector firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(run); + ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); auto sourceBatch = cast(getMaterializationRunSourceOp(run)); - SmallVector fragmentTypes = collectBatchOutputFragmentTypes(sourceBatch); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = getMaterializationRunLoc(run); for (const OutputDestinationGroup& group : groups) { @@ -3897,10 +4447,10 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, for (auto [runIndex, slot] : llvm::enumerate(run)) { assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); - SmallVector originalOutputs = getComputeInstanceOutputValues(slot.peers.front()); + ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, slot.peers.front()); Value originalOutput = originalOutputs[resultIndex]; - if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) + if (!hasLiveExternalUseCached(state, originalOutput)) continue; state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); @@ -3935,7 +4485,7 @@ bool canCompactBatchClassRun(MaterializerState& state, if (run.front().peers.empty()) return false; - SmallVector outputs = getComputeInstanceOutputValues(run.front().peers.front()); + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { for (const MaterializationRunSlot& slot : run) { @@ -3943,12 +4493,12 @@ bool canCompactBatchClassRun(MaterializerState& state, return false; for (const ComputeInstance& peer : slot.peers) { - SmallVector peerOutputs = getComputeInstanceOutputValues(peer); + ArrayRef peerOutputs = getComputeInstanceOutputValuesCached(state, peer); if (resultIndex >= peerOutputs.size()) return false; Value originalOutput = peerOutputs[resultIndex]; - if (hasLiveExternalUse(originalOutput, state.oldComputeOps)) + if (hasLiveExternalUseCached(state, originalOutput)) return false; ProducerKey key {peer, resultIndex}; @@ -3990,7 +4540,13 @@ Value createBatchClassRunSourceLane(MaterializerState& state, } Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - return createIndexedIndexValue(state, targetClass.op, sourceLanes, flatIndex, loc); + return createIndexedIndexValue(state, + targetClass.op, + sourceLanes, + flatIndex, + loc, + static_cast(targetClass.cpus.size()), + /*allowExhaustiveTiledSearch=*/false); } LogicalResult buildBatchRunSendPlans(MaterializerState& state, @@ -4051,9 +4607,10 @@ void appendBatchRunSend(MaterializerState& state, Location loc) { assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); - Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc); + std::optional preferredPeriod = static_cast(sourceClass.cpus.size()); + Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); } @@ -4157,7 +4714,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, SmallVector groups = groupBatchRunOutputsByDestination(state, run); auto sourceBatch = cast(run.front().peers.front().op); - SmallVector fragmentTypes = collectBatchOutputFragmentTypes(sourceBatch); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = sourceBatch.getLoc(); for (const OutputDestinationGroup& group : groups) { @@ -4257,11 +4814,17 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeIns if (failed(peers)) return instance.op->emitError("failed to collect peer compute instances for equivalence class logical slot"); - FailureOr> materializedOutputs = cloneInstanceBody(state, targetClass, *peers); + Value projectionSlotIndex = getOrCreateIndexConstant( + state.constantFolder, targetClass.op, static_cast(startLogicalSlot - logicalRange.start)); + FailureOr> materializedOutputs = + cloneInstanceBody(state, + targetClass, + *peers, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = projectionSlotIndex}); if (failed(materializedOutputs)) return failure(); - SmallVector originalOutputs = getComputeInstanceOutputValues(instance); + ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, instance); if (materializedOutputs->size() != originalOutputs.size()) return instance.op->emitError("materialized output count does not match original compute instance output count");