From 645539317b414cc309e9c296e0f8cef0f855994c Mon Sep 17 00:00:00 2001 From: ilgeco Date: Mon, 29 Jun 2026 15:21:28 +0200 Subject: [PATCH] Fix BB Arg used as input in external Op --- .../MaterializeMergeSchedule.cpp | 823 ++++++++++-------- 1 file changed, 456 insertions(+), 367 deletions(-) diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 31219e5..1d78c29 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -3,6 +3,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -11,9 +12,9 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include @@ -134,12 +135,14 @@ struct SameClassConsumerLookupKey { struct SameClassConsumerLookupKeyInfo { static SameClassConsumerLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + return {llvm::DenseMapInfo::getEmptyKey(), + std::numeric_limits::max(), std::numeric_limits::max()}; } static SameClassConsumerLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + return {llvm::DenseMapInfo::getTombstoneKey(), + std::numeric_limits::max(), std::numeric_limits::max()}; } @@ -164,12 +167,14 @@ struct WholeBatchAssemblyLookupKey { struct WholeBatchAssemblyLookupKeyInfo { static WholeBatchAssemblyLookupKey getEmptyKey() { - return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + return {llvm::DenseMapInfo::getEmptyKey(), + std::numeric_limits::max(), std::numeric_limits::max()}; } static WholeBatchAssemblyLookupKey getTombstoneKey() { - return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + return {llvm::DenseMapInfo::getTombstoneKey(), + std::numeric_limits::max(), std::numeric_limits::max()}; } @@ -407,11 +412,8 @@ FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState StringRef context, std::optional producer = std::nullopt, IRMapping* mapper = nullptr); -FailureOr materializeWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType, - Location loc); +FailureOr materializeWholeBatchInput( + MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc); FailureOr localizeMaterializedClassOperand(MaterializerState& state, MaterializedClass& targetClass, Value value, @@ -428,15 +430,13 @@ void createDim0ParallelInsertSlice( MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset); Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc); bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane); -std::optional getProjectedInputSliceMatch(MaterializerState& state, - SpatComputeBatch batch, - unsigned inputIndex); -std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, - SpatComputeBatch batch, - unsigned inputIndex); + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane); +std::optional +getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex); +std::optional +getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex); std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, tensor::ExtractSliceOp extract); FailureOr materializeProjectedWholeBatchExtractReplacement(MaterializerState& state, @@ -541,9 +541,7 @@ struct MaterializerState { SmallVector pendingProjectedHostOutputFragments; DenseSet oldComputeOps; - MaterializerState(func::FuncOp func, - const MergeScheduleResult& schedule, - int64_t& nextChannelId) + MaterializerState(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) : func(func), schedule(schedule), rewriter(func.getContext()), @@ -693,7 +691,8 @@ getPublicationLaneForProducerKey(MaterializerState& state, const MaterializedCla ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, key.instance); auto cpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); if (cpuIt == state.schedule.computeToCpuMap.end()) { - sourceClass.op->emitError("projected packed host publication could not resolve the producer CPU for a publication lane") + sourceClass.op->emitError( + "projected packed host publication could not resolve the producer CPU for a publication lane") << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount << " resultIndex=" << key.resultIndex; return failure(); @@ -866,8 +865,7 @@ bool canUseProjectedLaneInput(MaterializerState& state, if (!producerBatch || producerBatch.getNumResults() == 0) return false; - std::optional match = - getProjectedInputSliceMatch(state, consumerBatch, inputIndex); + std::optional match = getProjectedInputSliceMatch(state, consumerBatch, inputIndex); if (!match) return false; @@ -885,34 +883,31 @@ FailureOr classifyComputeBatchInputDemand(MaterializerState& s ComputeInstance logicalConsumer) { if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input)) { if (canUseProjectedLaneInput(state, consumerBatch, inputIndex, input, logicalConsumer)) - return BatchInputDemand { - .kind = BatchInputDemandKind::ProjectedFragment, .wholeTensorProducer = std::nullopt}; + return BatchInputDemand {.kind = BatchInputDemandKind::ProjectedFragment, .wholeTensorProducer = std::nullopt}; if (getProjectedWholeBatchReplacementProducer(state, consumerBatch, inputIndex)) - return BatchInputDemand { - .kind = BatchInputDemandKind::ProjectedFragment, .wholeTensorProducer = std::nullopt}; + return BatchInputDemand {.kind = BatchInputDemandKind::ProjectedFragment, .wholeTensorProducer = std::nullopt}; auto inputArg = consumerBatch.getInputArgument(inputIndex); if (!inputArg) return consumerBatch.emitOpError("expected compute_batch input block argument while classifying input demand") - << " #" << inputIndex; + << " #" << inputIndex; bool hasUses = false; for (OpOperand& use : inputArg->getUses()) { hasUses = true; if (!isa(use.getOwner())) - return BatchInputDemand { - .kind = BatchInputDemandKind::WholeTensorBarrier, .wholeTensorProducer = wholeBatchProducer}; + return BatchInputDemand {.kind = BatchInputDemandKind::WholeTensorBarrier, + .wholeTensorProducer = wholeBatchProducer}; } if (!hasUses) return BatchInputDemand {.kind = BatchInputDemandKind::LaneFragment, .wholeTensorProducer = std::nullopt}; return targetClass.op->emitError("failed to classify compute_batch input demand") - << " reason=direct whole-batch input only has projected uses, but no projected fragment path was proven" - << " consumerOp='" << consumerBatch->getName() << "' inputIndex=" << inputIndex - << " producerOp='" << wholeBatchProducer->instance.op->getName() << "' resultIndex=" - << wholeBatchProducer->resultIndex << " sourceClass=" << targetClass.id - << " valueType=" << input.getType(); + << " reason=direct whole-batch input only has projected uses, but no projected fragment path was proven" + << " consumerOp='" << consumerBatch->getName() << "' inputIndex=" << inputIndex << " producerOp='" + << wholeBatchProducer->instance.op->getName() << "' resultIndex=" << wholeBatchProducer->resultIndex + << " sourceClass=" << targetClass.id << " valueType=" << input.getType(); } return BatchInputDemand {.kind = BatchInputDemandKind::LaneFragment, .wholeTensorProducer = std::nullopt}; @@ -1206,7 +1201,8 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) { resultTypes.push_back(output.getType()); if (!materializedClass.isBatch) { - auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); + auto compute = + SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); compute.getProperties().setOperandSegmentSizes({0, 0}); auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id"); @@ -1317,6 +1313,15 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali return std::get<1>(*arg); } +bool isOldComputeRegionBlockArgument(MaterializerState& state, Value value) { + auto blockArg = dyn_cast(value); + if (!blockArg) + return false; + + Operation* owner = blockArg.getOwner()->getParentOp(); + return owner && state.oldComputeOps.contains(owner); +} + FailureOr appendScalarPublicationResult(MaterializerState& state, MaterializedClass& materializedClass, Value payload, @@ -1369,8 +1374,7 @@ FailureOr appendBatchPublicationResult(MaterializerState& state, SmallVector publishedShape(payloadType.getShape()); publishedShape[0] *= static_cast(materializedClass.cpus.size()); - auto publishedType = - RankedTensorType::get(publishedShape, payloadType.getElementType(), payloadType.getEncoding()); + auto publishedType = RankedTensorType::get(publishedShape, payloadType.getElementType(), payloadType.getEncoding()); FailureOr> inserted = batch.insertOutput(state.rewriter, batch.getNumResults(), publishedType, loc); @@ -1397,8 +1401,7 @@ FailureOr appendBatchPublicationResult(MaterializerState& state, } state.rewriter.setInsertionPoint(inParallelOp); - Value firstOffset = - scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc); + Value firstOffset = scaleIndexByDim0Size(state, materializedClass.op, *laneArg, payloadType.getDimSize(0), loc); state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); createDim0ParallelInsertSlice(state, loc, payload, outputArg, firstOffset); return result.getResultNumber(); @@ -1551,7 +1554,8 @@ std::string formatMaterializerOperandListInline(Operation* op, const Materialize stream << " blockArg#" << blockArg.getArgNumber(); if (Operation* owner = blockArg.getOwner()->getParentOp()) stream << " ownerOp='" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { + } + else if (Operation* definingOp = value.getDefiningOp()) { stream << " definingOp='" << definingOp->getName() << "'"; } } @@ -1610,7 +1614,8 @@ void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, stream << " blockArg#" << blockArg.getArgNumber(); if (Operation* owner = blockArg.getOwner()->getParentOp()) stream << " ownerOp='" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { + } + else if (Operation* definingOp = value.getDefiningOp()) { stream << " definingOp='" << definingOp->getName() << "'"; } stream << "\n"; @@ -1622,13 +1627,12 @@ void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value value, StringRef label) { if (auto blockArg = dyn_cast(value)) { if (Operation* owner = blockArg.getOwner()->getParentOp()) - diagnostic.attachNote(owner->getLoc()) - << label << " is block argument #" << blockArg.getArgNumber() << " of '" << owner->getName() - << "' with type " << blockArg.getType(); + diagnostic.attachNote(owner->getLoc()) << label << " is block argument #" << blockArg.getArgNumber() << " of '" + << owner->getName() << "' with type " << blockArg.getType(); else diagnostic.attachNote(UnknownLoc::get(value.getContext())) - << label << " is a top-level block argument #" << blockArg.getArgNumber() - << " with type " << blockArg.getType(); + << label << " is a top-level block argument #" << blockArg.getArgNumber() << " with type " + << blockArg.getType(); return; } @@ -1645,16 +1649,13 @@ void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value val void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { Block& body = *targetClass.body; diagnostic.attachNote(targetClass.op->getLoc()) - << "target class " << targetClass.id << " op '" << targetClass.op->getName() - << "' body has " << body.getNumArguments() << " block arguments and " - << std::distance(body.begin(), body.end()) << " top-level operations"; + << "target class " << targetClass.id << " op '" << targetClass.op->getName() << "' body has " + << body.getNumArguments() << " block arguments and " << std::distance(body.begin(), body.end()) + << " top-level operations"; } -FailureOr rematerializeIndexValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Location loc, - IRMapping* mapper = nullptr); +FailureOr rematerializeIndexValueInClass( + MaterializerState& state, MaterializedClass& targetClass, Value value, Location loc, IRMapping* mapper = nullptr); FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& state, MaterializedClass& targetClass, @@ -1670,11 +1671,8 @@ FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& return OpFoldResult(*rematerialized); } -FailureOr rematerializeIndexValueInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value value, - Location loc, - IRMapping* mapper) { +FailureOr rematerializeIndexValueInClass( + MaterializerState& state, MaterializedClass& targetClass, Value value, Location loc, IRMapping* mapper) { Value originalValue = value; bool mapperHadOriginalValue = false; Value mappedOriginalValue; @@ -1693,7 +1691,7 @@ FailureOr rematerializeIndexValueInClass(MaterializerState& state, if (!value.getType().isIndex()) return targetClass.op->emitError("cannot rematerialize non-index external value in materialized class body") - << " type=" << value.getType(); + << " type=" << value.getType(); if (auto constantIndex = value.getDefiningOp()) return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantIndex.value()); @@ -1702,7 +1700,7 @@ FailureOr rematerializeIndexValueInClass(MaterializerState& state, if (matchPattern(value, m_ConstantInt(&constantValue))) { if (!constantValue.isSignedIntN(64)) return targetClass.op->emitError("cannot rematerialize out-of-range index constant") - << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); + << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantValue.getSExtValue()); } @@ -1763,26 +1761,29 @@ FailureOr rematerializeIndexValueInClass(MaterializerState& state, Value tensor = extractOp.getTensor(); if (!isConstantLike(tensor) && !isValueLegalInMaterializedClassBody(tensor, targetClass)) return targetClass.op->emitError("cannot rematerialize indexed table lookup from external non-constant tensor") - << " tensorType=" << tensor.getType(); + << " tensorType=" << tensor.getType(); return tensor::ExtractOp::create(state.rewriter, loc, tensor, remappedIndices).getResult(); } if (auto blockArg = dyn_cast(value)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "cannot rematerialize external block argument in materialized class body"); + InFlightDiagnostic diagnostic = + targetClass.op->emitError("cannot rematerialize external block argument in materialized class body"); diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; if (Operation* owner = blockArg.getOwner()->getParentOp()) { diagnostic << " ownerOp='" << owner->getName() << "'"; - diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) << "\""; + diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) + << "\""; diagnostic << " ownerChain=\"" << formatMaterializerParentChainInline(owner) << "\""; } - diagnostic << " targetIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; + diagnostic << " targetIR=\"" + << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; if (mapper) { diagnostic << " mapperPresent=1 mapperHadOriginal=" << (mapperHadOriginalValue ? 1 : 0); if (mapperHadOriginalValue) diagnostic << " mappedType=" << mappedOriginalValue.getType(); - } else { + } + else { diagnostic << " mapperPresent=0"; } attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); @@ -1820,9 +1821,11 @@ InFlightDiagnostic emitNonLocalMaterializedClassValueDiagnostic(Operation* ancho if (producer) { diagnostic << " from '" << producer->instance.op->getName() << "' resultIndex=" << producer->resultIndex << " laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount; - } else if (auto result = dyn_cast(value)) { + } + else if (auto result = dyn_cast(value)) { diagnostic << " from '" << result.getOwner()->getName() << "' resultIndex=" << result.getResultNumber(); - } else if (auto blockArg = dyn_cast(value)) { + } + else if (auto blockArg = dyn_cast(value)) { diagnostic << " from block argument #" << blockArg.getArgNumber(); if (Operation* owner = blockArg.getOwner()->getParentOp()) diagnostic << " of '" << owner->getName() << "'"; @@ -1925,12 +1928,14 @@ FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState if (mapper && mapper->contains(value)) value = mapper->lookup(value); - if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) + if (!isa(value.getType()) || isConstantLike(value) + || isTensorValueLocalToMaterializedClass(value, targetClass)) return value; if (value.getDefiningOp() || value.getDefiningOp() || value.getDefiningOp()) { - FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); + FailureOr rematerialized = + rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); if (failed(rematerialized)) return failure(); return *rematerialized; @@ -1941,12 +1946,19 @@ FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState return failure(); } + if (isOldComputeRegionBlockArgument(state, value)) { + InFlightDiagnostic diagnostic = + anchor->emitError("cannot append old graph_compute region block argument as scheduled input"); + attachMaterializerValueOriginNote(diagnostic, value, "escaped input"); + return failure(); + } + return appendInput(state, targetClass, value); } std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, - Operation* anchor, - BlockArgument externalArg) { + Operation* anchor, + BlockArgument externalArg) { Block* sourceBlock = externalArg.getOwner(); Region* sourceRegion = sourceBlock ? sourceBlock->getParent() : nullptr; Operation* sourceParent = sourceRegion ? sourceRegion->getParentOp() : nullptr; @@ -1996,7 +2008,8 @@ FailureOr localizeMaterializedClassOperand(MaterializerState& state, return *localArg; if (isa(value.getType())) - return materializeTensorValueForMaterializedClassUse(state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); + return materializeTensorValueForMaterializedClassUse( + state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); if (isValueLegalInMaterializedClassBody(value, targetClass)) return value; @@ -2010,9 +2023,10 @@ FailureOr localizeMaterializedClassOperand(MaterializerState& state, diagnostic << " blockArg#" << blockArg.getArgNumber(); if (Operation* owner = blockArg.getOwner()->getParentOp()) diagnostic.attachNote(owner->getLoc()) << "block argument belongs to '" << owner->getName() << "'"; - } else if (Operation* definingOp = value.getDefiningOp()) { - diagnostic.attachNote(definingOp->getLoc()) << "unsupported external operand producer is '" << definingOp->getName() - << "'"; + } + else if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) + << "unsupported external operand producer is '" << definingOp->getName() << "'"; } return failure(); } @@ -2069,8 +2083,7 @@ FailureOr createDim0ExtractSliceInClass(MaterializerState& state, "createDim0ExtractSliceInClass tried to reuse a tensor from another materialized class"); if (failed(localizedSource)) return failure(); - FailureOr localizedOffset = - rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + FailureOr localizedOffset = rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); if (failed(localizedOffset)) return failure(); return createDim0ExtractSlice(state, loc, *localizedSource, *localizedOffset, firstSize); @@ -2118,8 +2131,7 @@ FailureOr createStaticExtractSliceInClass(MaterializerState& state, SmallVector localizedOffsets; localizedOffsets.reserve(sliceOffsets.size()); for (OpFoldResult offset : sliceOffsets) { - FailureOr localized = - rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); + FailureOr localized = rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); if (failed(localized)) return failure(); localizedOffsets.push_back(*localized); @@ -2135,11 +2147,12 @@ Value createIndexedIndexValue(MaterializerState& state, std::optional preferredPeriod = std::nullopt, bool allowExhaustiveTiledSearch = true); -FailureOr> buildProjectedFragmentOffsetsInClass(MaterializerState& state, - MaterializedClass& targetClass, - const ProjectedTransferDescriptor& descriptor, - Value flatFragmentIndex, - Location loc) { +FailureOr> +buildProjectedFragmentOffsetsInClass(MaterializerState& state, + MaterializedClass& targetClass, + const ProjectedTransferDescriptor& descriptor, + Value flatFragmentIndex, + Location loc) { FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, flatFragmentIndex, loc); if (failed(localizedIndex)) return failure(); @@ -2195,8 +2208,7 @@ FailureOr createDim0InsertSliceInClass(MaterializerState& state, "createDim0InsertSliceInClass tried to reuse a destination tensor from another materialized class"); if (failed(localizedDestination)) return failure(); - FailureOr localizedOffset = - rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + FailureOr localizedOffset = rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); if (failed(localizedOffset)) return failure(); return createDim0InsertSlice(state, loc, *localizedFragment, *localizedDestination, *localizedOffset); @@ -2220,11 +2232,8 @@ Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value in return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, anchor); } -FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, - MaterializedClass& targetClass, - Value index, - int64_t dim0Size, - Location loc) { +FailureOr scaleIndexByDim0SizeInClass( + MaterializerState& state, MaterializedClass& targetClass, Value index, int64_t dim0Size, Location loc) { FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, index, loc); if (failed(localizedIndex)) return failure(); @@ -2307,12 +2316,8 @@ Value getPackedSliceForRunIndex(MaterializerState& state, return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } -Value getPackedSliceForDynamicRunIndex(MaterializerState& state, - Operation* anchor, - Value packed, - RankedTensorType fragmentType, - Value index, - Location loc) { +Value getPackedSliceForDynamicRunIndex( + MaterializerState& state, Operation* anchor, Value packed, RankedTensorType fragmentType, Value index, Location loc) { Value firstOffset = scaleIndexByDim0Size(state, anchor, index, fragmentType.getDimSize(0), loc); return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } @@ -2397,8 +2402,7 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); auto keyIt = llvm::find(slot.keys, key); - if ((!contiguousKey && keyIt == slot.keys.end()) - || (contiguousKey && !containsProducerKey(*contiguousKey, key))) + if ((!contiguousKey && keyIt == slot.keys.end()) || (contiguousKey && !containsProducerKey(*contiguousKey, key))) continue; MaterializedClass& materializedClass = state.classes[classId]; @@ -2460,9 +2464,8 @@ IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { - if (std::optional exact = lookupExact(key, classId)) { + if (std::optional exact = lookupExact(key, classId)) return exact; - } if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) return packedRunValue; @@ -2608,9 +2611,8 @@ Value createIndexedIndexValue(MaterializerState& state, bool allowExhaustiveTiledSearch) { assert(!values.empty() && "expected at least one indexed value"); - if (allEqual(values)) { + if (allEqual(values)) return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); - } if (std::optional pattern = getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) @@ -2631,11 +2633,8 @@ Value createIndexedIndexValue( return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, std::nullopt, true); } -OpFoldResult createIndexedOrStaticIndex(MaterializerState& state, - Operation* anchor, - ArrayRef values, - Value index, - Location loc) { +OpFoldResult createIndexedOrStaticIndex( + MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { assert(!values.empty() && "expected at least one indexed value"); if (allEqual(values)) return state.rewriter.getIndexAttr(values.front()); @@ -2801,8 +2800,7 @@ bool hasRealComputeConsumer(Value value, const DenseSet& oldComputeO return false; } -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); +FailureOr getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps) { auto batch = dyn_cast_or_null(output.getDefiningOp()); @@ -2813,7 +2811,6 @@ bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComp return !hasRealComputeConsumer(output, oldComputeOps); } - void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { SmallVector& destinations = state.producerDestClasses[key]; if (!llvm::is_contained(destinations, classId)) @@ -2856,7 +2853,7 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); if (sourceClass == targetClass) { - SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass}; + SameClassConsumerLookupKey lookupKey {producerKey.instance.op, producerKey.resultIndex, targetClass}; SmallVector& bucket = state.sameClassConsumerIndex[lookupKey]; if (!llvm::is_contained(bucket, producerKey)) bucket.push_back(producerKey); @@ -2892,7 +2889,6 @@ bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceTyp return true; } - bool isStaticSliceContainedIn(ArrayRef innerOffsets, ArrayRef innerSizes, ArrayRef outerOffsets, @@ -3228,14 +3224,8 @@ getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeB if (!wholeBatchProducer) return std::nullopt; - if (canUseProjectedLaneInput( - state, - batch, - inputIndex, - input, - ComputeInstance {batch.getOperation(), 0, 1})) { + if (canUseProjectedLaneInput(state, batch, inputIndex, input, ComputeInstance {batch.getOperation(), 0, 1})) return std::nullopt; - } auto producerBatch = dyn_cast_or_null(wholeBatchProducer->instance.op); if (!producerBatch) @@ -3247,8 +3237,8 @@ getProjectedWholeBatchReplacementProducer(MaterializerState& state, SpatComputeB return wholeBatchProducer; } -std::optional -getProjectedWholeBatchReplacementProducer(MaterializerState& state, tensor::ExtractSliceOp extract) { +std::optional getProjectedWholeBatchReplacementProducer(MaterializerState& state, + tensor::ExtractSliceOp extract) { auto sourceArg = dyn_cast(extract.getSource()); if (!sourceArg) return std::nullopt; @@ -3313,8 +3303,7 @@ FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); } -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { +FailureOr getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); if (!inParallel) return failure(); @@ -3353,11 +3342,10 @@ evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, ui return evaluated; } - bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane) { + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane) { auto producerBatch = dyn_cast_or_null(producer.instance.op); if (!producerBatch) return true; @@ -3423,7 +3411,6 @@ bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consu return recurse(recurse, 0); } - LogicalResult collectProjectedTransfers(MaterializerState& state) { struct PendingProjectedTransferDescriptor { ProjectedBatchInputKey inputKey; @@ -3836,7 +3823,8 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, .getResult(); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); auto loop = buildNormalizedScfFor( @@ -3865,8 +3853,8 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, if (failed(fragment)) return failure(); - FailureOr packedOffset = - scaleIndexByDim0SizeInClass(state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); + FailureOr packedOffset = scaleIndexByDim0SizeInClass( + state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); if (failed(packedOffset)) return failure(); FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); @@ -3976,7 +3964,7 @@ LogicalResult appendSend(MaterializerState& state, state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); if (messages.size() != sourceClass.cpus.size()) return sourceClass.op->emitError("batch send expects exactly one message per materialized lane") - << " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size(); + << " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size(); Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); @@ -4358,8 +4346,7 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, if (failed(targetCpu)) return failure(); if (keys.size() != sourceClass.cpus.size()) - return sourceClass.op->emitError( - "batch-to-scalar communication expects one producer key per source lane"); + return sourceClass.op->emitError("batch-to-scalar communication expects one producer key per source lane"); for (CpuId sourceCpu : sourceClass.cpus) { auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); if (failed(checkedSourceCpu)) @@ -4413,10 +4400,10 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); if (resultIt == sourceClass.hostOutputToResultIndex.end()) return sourceClass.op->emitError("missing host result slot for materialized output") - << " ownerKind=" << (sourceClass.isBatch ? "batch" : "scalar") - << " hostOutputs=" << sourceClass.hostOutputs.size() - << " originalDef=" << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() - : StringRef("")); + << " ownerKind=" << (sourceClass.isBatch ? "batch" : "scalar") + << " hostOutputs=" << sourceClass.hostOutputs.size() << " originalDef=" + << (originalOutput.getDefiningOp() ? originalOutput.getDefiningOp()->getName().getStringRef() + : StringRef("")); unsigned resultIndex = resultIt->second; state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); @@ -4429,7 +4416,7 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val return sourceClass.op->emitError("host result index out of range for materialized compute"); if (payload.getType() != originalOutput.getType()) return sourceClass.op->emitError("cannot set scalar host output from fragment payload") - << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); return success(); @@ -4505,7 +4492,7 @@ emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, return ownerClass.op->emitError("generic host publication does not support batch host owners"); if (payload.getType() != originalOutput.getType()) return sourceClass.op->emitError("cannot forward fragment payload to scalar host owner") - << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); MessageVector messages; auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); @@ -4521,11 +4508,11 @@ emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, } LogicalResult emitOutputFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { if (keys.empty()) return success(); @@ -4554,8 +4541,7 @@ LogicalResult emitOutputFanout(MaterializerState& state, return failure(); if (hasLiveExternalUseCached(state, originalOutput) && !recordedProjectedHostFragments) - return sourceClass.op->emitError( - "batch host publication requires explicit fragment assembly metadata"); + return sourceClass.op->emitError("batch host publication requires explicit fragment assembly metadata"); if (!recordedProjectedHostFragments && failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) return failure(); @@ -4709,11 +4695,11 @@ FailureOr insertFragmentIntoWholeBatch(MaterializerState& state, } FailureOr extractPackedSlotForIndex(MaterializerState& state, - MaterializedClass& targetClass, - Value packed, - RankedTensorType slotPackedType, - Value slotIndex, - Location loc) { + MaterializedClass& targetClass, + Value packed, + RankedTensorType slotPackedType, + Value slotIndex, + Location loc) { FailureOr firstOffset = scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); if (failed(firstOffset)) @@ -4742,7 +4728,6 @@ bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScal return true; } - std::optional getConstantIndexValue(Value value) { APInt constant; if (matchPattern(value, m_ConstantInt(&constant))) @@ -4763,7 +4748,8 @@ bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelRec PackedScalarRunValue* findDeferredReceiveAlternativeForPackedRun(MaterializerState& state, const MaterializedClass& targetClass, const PackedScalarRunValue& run) { - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); + WholeBatchAssemblyLookupKey lookupKey = + makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); for (size_t runIndex : runIndices) { @@ -4889,7 +4875,8 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); if (failed(firstOffset)) return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); + FailureOr next = + createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); if (failed(next)) return failure(); yielded.push_back(*next); @@ -4956,9 +4943,8 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, }; uint32_t batchLaneCount = static_cast(batch.getLaneCount()); - if (plan.coveredLaneCount == plan.batchLaneCount) { + if (plan.coveredLaneCount == plan.batchLaneCount) return success(); - } WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); ArrayRef indexedFragments = @@ -4996,9 +4982,8 @@ LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, size_t candidateCursor = 0; uint32_t lane = 0; while (lane < batchLaneCount) { - while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) { + while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) ++lane; - } if (lane >= batchLaneCount) break; @@ -5302,12 +5287,13 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, "whole-batch direct fragment assembly tried to reuse a tensor from another materialized class"); if (failed(localFragment)) return failure(); - FailureOr updated = createDim0InsertSliceInClass(state, - targetClass, - loc, - *localFragment, - destination, - getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); + FailureOr updated = + createDim0InsertSliceInClass(state, + targetClass, + loc, + *localFragment, + destination, + getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); if (failed(updated)) return failure(); destination = *updated; @@ -5318,19 +5304,17 @@ FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, return failure(); } -FailureOr emitProjectedWholeBatchFragmentInsertLoop( - MaterializerState& state, - MaterializedClass& targetClass, - Value destination, - const ProjectedWholeBatchFragmentGroup& group, - llvm::function_ref(Value)> buildFragment, - Location loc) { +FailureOr emitProjectedWholeBatchFragmentInsertLoop(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const ProjectedWholeBatchFragmentGroup& group, + llvm::function_ref(Value)> buildFragment, + Location loc) { assert(group.fragmentType && "expected projected fragment type"); assert(!group.offsetsByDim.empty() && "expected projected insert coordinates"); Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); @@ -5472,11 +5456,8 @@ FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, // Run materialization helpers. // ----------------------------------------------------------------------------- -FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey key, - Type resultType, - Location loc) { +FailureOr materializeProjectedWholeBatchInputFromFragments( + MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { auto batch = dyn_cast_or_null(key.instance.op); auto resultTensorType = dyn_cast(resultType); if (!batch || !resultTensorType || !resultTensorType.hasStaticShape()) @@ -5538,9 +5519,9 @@ FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerSt return failure(); state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - result = tensor::InsertSliceOp::create( - state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs) - .getResult(); + result = + tensor::InsertSliceOp::create(state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs) + .getResult(); } state.availableValues.record(key, targetClass.id, result); @@ -5685,8 +5666,8 @@ FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerSt extractOffsets.reserve(group.packedSourceType.getRank()); extractSizes.reserve(group.packedSourceType.getRank()); extractStrides.reserve(group.packedSourceType.getRank()); - extractOffsets.push_back(createIndexedOrStaticIndex( - state, targetClass.op, group.packedIndices, flatIndex, loc)); + extractOffsets.push_back( + createIndexedOrStaticIndex(state, targetClass.op, group.packedIndices, flatIndex, loc)); extractSizes.push_back(state.rewriter.getIndexAttr(1)); extractStrides.push_back(state.rewriter.getIndexAttr(1)); for (int64_t dim = 1; dim < group.packedSourceType.getRank(); ++dim) { @@ -5705,13 +5686,7 @@ FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerSt return failure(); return tensor::ExtractSliceOp::create( - state.rewriter, - loc, - group.fragmentType, - *packed, - extractOffsets, - extractSizes, - extractStrides) + state.rewriter, loc, group.fragmentType, *packed, extractOffsets, extractSizes, extractStrides) .getResult(); }, loc); @@ -5824,8 +5799,7 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat FailureOr> strides = evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, peer.laneStart); if (failed(offsets) || failed(sizes) || failed(strides)) { - sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") - << peer.laneStart; + sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") << peer.laneStart; return failure(); } @@ -5840,7 +5814,8 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat SmallVector(*sizes), SmallVector(*strides), peer.laneStart, - loc}); + loc + }); } return true; @@ -5880,7 +5855,7 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, keys.front().instance.laneStart); if (failed(firstSizes)) return sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") - << keys.front().instance.laneStart; + << keys.front().instance.laneStart; SmallVector fragmentShape(*firstSizes); auto fragmentType = RankedTensorType::get(fragmentShape, packedType.getElementType(), packedType.getEncoding()); @@ -5895,24 +5870,23 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt size_t keysPerPublishedPayload = keys.size(); if (sourceClass.isBatch) { if (sourceClass.cpus.empty() || keys.size() % sourceClass.cpus.size() != 0) - return sourceClass.op->emitError( - "projected packed host publication requires a stable per-lane key partition") - << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size() - << " laneCount=" << sourceClass.cpus.size(); + return sourceClass.op->emitError("projected packed host publication requires a stable per-lane key partition") + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size() + << " laneCount=" << sourceClass.cpus.size(); keysPerPublishedPayload = keys.size() / sourceClass.cpus.size(); } if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast(keysPerPublishedPayload) != 0) - return sourceClass.op->emitError( - "projected packed host publication requires either direct fragment operands or evenly dim-0 packed fragments") - << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + return sourceClass.op->emitError("projected packed host publication requires either direct fragment operands or " + "evenly dim-0 packed fragments") + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); SmallVector packedFragmentShape(packedType.getShape()); packedFragmentShape[0] /= static_cast(keysPerPublishedPayload); if (packedFragmentShape != fragmentShape) return sourceClass.op->emitError( "projected packed host publication fragment shape does not match projected slice size") - << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); } int64_t payloadElementCount = packedType.getNumElements(); @@ -5925,12 +5899,14 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt || static_cast(keysPerPublishedPayload) % fragmentsPerPublishedPayload != 0) return sourceClass.op->emitOpError( "projected packed host publication requires a deterministic publication packing layout") - << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); DenseMap publishedFragmentOrdinals; for (auto [fragmentIndex, key] : llvm::enumerate(keys)) { - if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex || key.instance.laneCount != 1) - return sourceClass.op->emitError("projected packed host publication requires one-lane keys from one producer result"); + if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex + || key.instance.laneCount != 1) + return sourceClass.op->emitError( + "projected packed host publication requires one-lane keys from one producer result"); FailureOr> offsets = evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); @@ -5940,10 +5916,10 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); if (failed(offsets) || failed(sizes) || failed(strides)) return sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") - << key.instance.laneStart; + << key.instance.laneStart; if (SmallVector(*sizes) != fragmentShape) return sourceClass.op->emitError( - "projected packed host publication requires one operand to map to a consistent fragment shape"); + "projected packed host publication requires one operand to map to a consistent fragment shape"); FailureOr publishedLaneIndex = getPublicationLaneForProducerKey(state, sourceClass, key); if (failed(publishedLaneIndex)) @@ -5996,8 +5972,8 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { for (Value originalOutput : outputs) { if (isa_and_present(originalOutput.getDefiningOp())) { - return state.func.emitError( - "projected host output assembly must be keyed by the original logical host output, not by a materialized scheduled result"); + return state.func.emitError("projected host output assembly must be keyed by the original logical host output, " + "not by a materialized scheduled result"); } auto resultType = dyn_cast(originalOutput.getType()); @@ -6005,19 +5981,17 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { return state.func.emitError("projected host output must have static ranked tensor type"); SmallVector& fragments = byOutput[originalOutput]; - llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, - const PendingProjectedHostOutputFragment* rhs) { - if (lhs->sourceClass != rhs->sourceClass) - return lhs->sourceClass < rhs->sourceClass; - if (lhs->publicationResultIndex != rhs->publicationResultIndex) - return lhs->publicationResultIndex < rhs->publicationResultIndex; - if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) - return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; - return std::lexicographical_compare(lhs->offsets.begin(), - lhs->offsets.end(), - rhs->offsets.begin(), - rhs->offsets.end()); - }); + llvm::sort(fragments, + [](const PendingProjectedHostOutputFragment* lhs, const PendingProjectedHostOutputFragment* rhs) { + if (lhs->sourceClass != rhs->sourceClass) + return lhs->sourceClass < rhs->sourceClass; + if (lhs->publicationResultIndex != rhs->publicationResultIndex) + return lhs->publicationResultIndex < rhs->publicationResultIndex; + if (lhs->sourceFragmentOrdinal != rhs->sourceFragmentOrdinal) + return lhs->sourceFragmentOrdinal < rhs->sourceFragmentOrdinal; + return std::lexicographical_compare( + lhs->offsets.begin(), lhs->offsets.end(), rhs->offsets.begin(), rhs->offsets.end()); + }); state.rewriter.setInsertionPoint(returnOp); Location loc = fragments.front()->loc; @@ -6036,9 +6010,8 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; if (fragmentRecord->publicationResultIndex >= sourceClass.op->getNumResults()) { return sourceClass.op->emitError("projected host output fragment references an invalid publication result") - << " sourceClass=" << sourceClass.id - << " resultIndex=" << fragmentRecord->publicationResultIndex - << " resultCount=" << sourceClass.op->getNumResults(); + << " sourceClass=" << sourceClass.id << " resultIndex=" << fragmentRecord->publicationResultIndex + << " resultCount=" << sourceClass.op->getNumResults(); } Value operand = sourceClass.op->getResult(fragmentRecord->publicationResultIndex); @@ -6063,23 +6036,22 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { Value input = blueprintOperands.front(); ValueRange extraFragments = ValueRange(blueprintOperands).drop_front(); - auto blueprint = spatial::SpatBlueprintOp::create( - state.rewriter, - loc, - resultType, - input, - extraFragments, - state.rewriter.getStringAttr("nchw"), - state.rewriter.getStringAttr("fragmented"), - state.rewriter.getDenseI64ArrayAttr(flatOffsets), - state.rewriter.getDenseI64ArrayAttr(flatSizes), - state.rewriter.getStringAttr("identity"), - state.rewriter.getStringAttr("fragment_assembly"), - state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices), - state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets), - state.rewriter.getDenseI64ArrayAttr(flatStrides), - state.rewriter.getStringAttr("disjoint"), - state.rewriter.getStringAttr("complete")); + auto blueprint = spatial::SpatBlueprintOp::create(state.rewriter, + loc, + resultType, + input, + extraFragments, + state.rewriter.getStringAttr("nchw"), + state.rewriter.getStringAttr("fragmented"), + state.rewriter.getDenseI64ArrayAttr(flatOffsets), + state.rewriter.getDenseI64ArrayAttr(flatSizes), + state.rewriter.getStringAttr("identity"), + state.rewriter.getStringAttr("fragment_assembly"), + state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices), + state.rewriter.getDenseI64ArrayAttr(fragmentSourceOffsets), + state.rewriter.getDenseI64ArrayAttr(flatStrides), + state.rewriter.getStringAttr("disjoint"), + state.rewriter.getStringAttr("complete")); state.hostReplacements[originalOutput] = blueprint.getOutput(); } @@ -6098,11 +6070,12 @@ FailureOr resolveInputValue(MaterializerState& state, return resolved; std::optional producer = getInputRequestProducerKey(input, consumerInstance); - emitNonLocalMaterializedClassValueDiagnostic(consumerInstance.op, - targetClass, - "input resolution tried to reuse a tensor from another materialized class", - resolved, - producer); + emitNonLocalMaterializedClassValueDiagnostic( + consumerInstance.op, + targetClass, + "input resolution tried to reuse a tensor from another materialized class", + resolved, + producer); return failure(); }; @@ -6123,7 +6096,6 @@ FailureOr resolveInputValue(MaterializerState& state, if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) return rejectNonLocalResolvedValue(*value); - if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { size_t laneCount = targetClass.cpus.size(); for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) { @@ -6133,9 +6105,14 @@ FailureOr resolveInputValue(MaterializerState& state, Value received = Value(); if (indexedRun->packed) { state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - received = getPackedSliceForRunIndex( - state, targetClass.op, indexedRun->packed, indexedRun->fragmentType, slotIndex, consumerInstance.op->getLoc()); - } else { + received = getPackedSliceForRunIndex(state, + targetClass.op, + indexedRun->packed, + indexedRun->fragmentType, + slotIndex, + consumerInstance.op->getLoc()); + } + else { MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); received = appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc()); @@ -6179,12 +6156,45 @@ FailureOr resolveInputValue(MaterializerState& state, return failure(); } + if (isOldComputeRegionBlockArgument(state, input)) { + InFlightDiagnostic diagnostic = + consumerInstance.op->emitError("cannot append old graph_compute region block argument as scheduled input"); + attachMaterializerValueOriginNote(diagnostic, input, "escaped input"); + return failure(); + } + return appendInput(state, targetClass, input); } +bool hasPlannedProjectedInputTransfer(MaterializerState& state, + SpatComputeBatch batch, + unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer, + ClassId classId) { + ProjectedBatchInputKey inputKey {batch.getOperation(), inputIndex}; + SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); + for (ProducerKey producer : producers) { + auto producerIt = state.projectedTransfers.find(producer); + if (producerIt == state.projectedTransfers.end()) + continue; + + auto descriptorIt = producerIt->second.find(classId); + if (descriptorIt == producerIt->second.end()) + continue; + + if (descriptorIt->second.inputKey == inputKey) + return true; + } + + return false; +} + bool hasProjectedInputReplacement(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer, ClassId classId) { std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); if (!match) @@ -6195,6 +6205,9 @@ bool hasProjectedInputReplacement(MaterializerState& state, && replacementIt->second.find(classId) != replacementIt->second.end()) return true; + if (hasPlannedProjectedInputTransfer(state, batch, inputIndex, input, logicalConsumer, classId)) + return true; + return getProjectedWholeBatchReplacementProducer(state, batch, inputIndex).has_value(); } @@ -6252,11 +6265,12 @@ LogicalResult mapInputs(MaterializerState& state, return compute.emitOpError("expected compute input block argument while materializing inputs"); FailureOr remapped = mapResolvedInput(*mapped); if (failed(remapped)) { - emitNonLocalMaterializedClassValueDiagnostic(compute, - targetClass, - "mapInputs tried to append a tensor from another materialized class", - *mapped, - getInputRequestProducerKey(input, instance)); + emitNonLocalMaterializedClassValueDiagnostic( + compute, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + getInputRequestProducerKey(input, instance)); return failure(); } mapper.map(*inputArg, *remapped); @@ -6266,7 +6280,8 @@ LogicalResult mapInputs(MaterializerState& state, auto batch = cast(op); for (auto [index, input] : llvm::enumerate(batch.getInputs())) { - if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) + if (hasProjectedInputReplacement( + state, batch, static_cast(index), input, instance, targetClass.id)) continue; FailureOr demand = @@ -6281,13 +6296,13 @@ LogicalResult mapInputs(MaterializerState& state, state, targetClass, *demand->wholeTensorProducer, input.getType(), batch.getOperation()->getLoc()); if (failed(mapped)) return batch.emitOpError("failed to materialize whole-batch compute_batch input") - << " #" << index << " from '" << demand->wholeTensorProducer->instance.op->getName() - << "' laneStart=" << demand->wholeTensorProducer->instance.laneStart - << " laneCount=" << demand->wholeTensorProducer->instance.laneCount - << " resultIndex=" << demand->wholeTensorProducer->resultIndex; - } else { - mapped = resolveInputValue( - state, targetClass, input, instance, indexing, /*allowWholeBatchFallback=*/false); + << " #" << index << " from '" << demand->wholeTensorProducer->instance.op->getName() + << "' laneStart=" << demand->wholeTensorProducer->instance.laneStart + << " laneCount=" << demand->wholeTensorProducer->instance.laneCount + << " resultIndex=" << demand->wholeTensorProducer->resultIndex; + } + else { + mapped = resolveInputValue(state, targetClass, input, instance, indexing, /*allowWholeBatchFallback=*/false); if (failed(mapped)) return batch.emitOpError("failed to resolve materialized compute_batch input"); } @@ -6405,9 +6420,8 @@ FailureOr materializeProjectedWholeBatchExtractReplacement(MaterializerSt auto remapFoldResult = [&](OpFoldResult value) -> FailureOr { if (auto mappedValue = dyn_cast_if_present(value)) { - FailureOr localized = - rematerializeIndexValueInClass(state, targetClass, mapper ? mapper->lookupOrDefault(mappedValue) : mappedValue, - extract.getLoc(), mapper); + FailureOr localized = rematerializeIndexValueInClass( + state, targetClass, mapper ? mapper->lookupOrDefault(mappedValue) : mappedValue, extract.getLoc(), mapper); if (failed(localized)) return failure(); return OpFoldResult(*localized); @@ -6470,15 +6484,14 @@ LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& sta return success(); } - if (std::optional producer = - getProjectedWholeBatchReplacementProducer(state, originalExtract)) { + if (std::optional producer = getProjectedWholeBatchReplacementProducer(state, originalExtract)) { auto clonedExtract = dyn_cast(&clonedOp); if (!clonedExtract) return targetClass.op->emitError("projected whole-batch replacement lost extract structure during cloning"); state.rewriter.setInsertionPoint(clonedExtract); - FailureOr projected = materializeProjectedWholeBatchExtractReplacement( - state, targetClass, clonedExtract, *producer, &mapper); + FailureOr projected = + materializeProjectedWholeBatchExtractReplacement(state, targetClass, clonedExtract, *producer, &mapper); if (failed(projected)) return failure(); @@ -6531,7 +6544,8 @@ LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& cl if (!mapper.contains(originalArg)) mapper.map(originalArg, clonedArg); - if (std::distance(originalBlock.begin(), originalBlock.end()) != std::distance(clonedBlock.begin(), clonedBlock.end())) + if (std::distance(originalBlock.begin(), originalBlock.end()) + != std::distance(clonedBlock.begin(), clonedBlock.end())) return clonedOp.emitError("cloned operation block has a different number of operations than the source block"); auto originalIt = originalBlock.begin(); @@ -6548,6 +6562,69 @@ LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& cl return success(); } +static std::optional getConstantIndex(OpFoldResult value); +bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceType, RankedTensorType fragmentType); + +FailureOr> tryNormalizeLocalizedExtractSlice(tensor::ExtractSliceOp extract, IRMapping& mapper) { + Value localizedSource = mapper.lookupOrDefault(extract.getSource()); + if (localizedSource == extract.getSource() || localizedSource.getType() == extract.getSource().getType()) + return std::optional {}; + + auto localizedSourceType = dyn_cast(localizedSource.getType()); + auto resultType = dyn_cast(extract.getType()); + if (!localizedSourceType || !resultType || !localizedSourceType.hasStaticShape() || !resultType.hasStaticShape()) + return std::optional {}; + if (localizedSourceType != resultType) + return std::optional {}; + + if (extract.getMixedSizes().size() != static_cast(localizedSourceType.getRank()) + || extract.getMixedStrides().size() != static_cast(localizedSourceType.getRank())) + return std::optional {}; + + for (int64_t dim = 0; dim < localizedSourceType.getRank(); ++dim) { + std::optional size = getConstantIndex(extract.getMixedSizes()[dim]); + std::optional stride = getConstantIndex(extract.getMixedStrides()[dim]); + if (!size || !stride || *stride != 1 || *size != localizedSourceType.getDimSize(dim)) + return std::optional {}; + } + + return std::optional {localizedSource}; +} + +LogicalResult verifyMaterializedStaticExtractSlice(Operation* anchor, + Value source, + RankedTensorType resultType, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + auto sourceType = dyn_cast(source.getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) + return success(); + if (offsets.size() != static_cast(sourceType.getRank()) + || sizes.size() != static_cast(sourceType.getRank()) + || strides.size() != static_cast(sourceType.getRank())) + return success(); + + SmallVector staticOffsets; + staticOffsets.reserve(offsets.size()); + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + std::optional offset = getConstantIndex(offsets[dim]); + std::optional size = getConstantIndex(sizes[dim]); + std::optional stride = getConstantIndex(strides[dim]); + if (!offset || !size || !stride) + return success(); + if (*stride != 1 || *size != resultType.getDimSize(dim)) + return success(); + staticOffsets.push_back(*offset); + } + + if (isStaticSliceInBounds(staticOffsets, sourceType, resultType)) + return success(); + + return anchor->emitError("materializer produced statically out-of-bounds extract_slice") + << " sourceType=" << sourceType << " resultType=" << resultType; +} + LogicalResult cloneComputeTemplateBody(MaterializerState& state, MaterializedClass& targetClass, const ComputeInstance& instance, @@ -6597,6 +6674,14 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, } if (auto extract = dyn_cast(&op)) { + FailureOr> normalized = tryNormalizeLocalizedExtractSlice(extract, mapper); + if (failed(normalized)) + return failure(); + if (*normalized) { + mapper.map(extract.getResult(), **normalized); + continue; + } + auto remapFoldResult = [&](OpFoldResult value) -> OpFoldResult { if (auto mappedValue = dyn_cast_if_present(value)) return mapper.lookupOrDefault(mappedValue); @@ -6616,6 +6701,9 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, auto resultType = cast(extract.getType()); Value localizedSource = mapper.lookupOrDefault(extract.getSource()); + if (failed(verifyMaterializedStaticExtractSlice( + extract.getOperation(), localizedSource, resultType, offsets, sizes, strides))) + return failure(); Value localizedExtract = extractStaticSliceOrIdentity( state.rewriter, extract.getLoc(), localizedSource, resultType, offsets, sizes, strides); mapper.map(extract.getResult(), localizedExtract); @@ -6625,11 +6713,11 @@ LogicalResult cloneComputeTemplateBody(MaterializerState& state, Operation* cloned = state.rewriter.clone(op, mapper); if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper))) return failure(); - if (failed(localizeCapturesInClonedOp(state, targetClass, *cloned, &mapper))) - return failure(); if (op.getNumRegions() != 0 && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing, mapper))) return failure(); + if (failed(localizeCapturesInClonedOp(state, targetClass, *cloned, &mapper))) + return failure(); for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) mapper.map(oldResult, newResult); } @@ -6792,11 +6880,10 @@ LogicalResult localizeCapturesInOperationTree(MaterializerState& state, FailureOr localized = localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper); if (failed(localized)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "failed to localize cloned scheduled-body operand"); - diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() - << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() - << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + InFlightDiagnostic diagnostic = targetClass.op->emitError("failed to localize cloned scheduled-body operand"); + diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() << "' operand#" + << operand.getOperandNumber() << " operandType=" << current.getType() << " offendingIR=\"" + << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; @@ -6852,11 +6939,10 @@ LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, Materia "final scheduled body capture localization tried to reuse a tensor from another materialized class", "final scheduled body capture localization found an unsupported external non-tensor operand"); if (failed(localized)) { - InFlightDiagnostic diagnostic = targetClass.op->emitError( - "failed to localize final scheduled-body operand"); - diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() - << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() - << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + InFlightDiagnostic diagnostic = targetClass.op->emitError("failed to localize final scheduled-body operand"); + diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() << "' operand#" + << operand.getOperandNumber() << " operandType=" << current.getType() << " offendingIR=\"" + << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; @@ -7229,34 +7315,34 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange(initValues), - [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange(initValues), + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - run.front().peers.front(), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); - if (failed(produced)) - return failure(); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + run.front().peers.front(), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); + if (failed(produced)) + return failure(); - yielded.reserve(produced->size()); - for (auto [outputIndex, output] : llvm::enumerate(*produced)) { - auto fragmentType = cast(output.getType()); - Value acc = iterArgs[outputIndex]; - Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); - yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); - } - return success(); - }); + yielded.reserve(produced->size()); + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = cast(output.getType()); + Value acc = iterArgs[outputIndex]; + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); + } + return success(); + }); if (failed(loop)) return failure(); @@ -7408,8 +7494,7 @@ StringRef describeWholeTensorBarrierReason(WholeTensorBarrierReason reason) { switch (reason) { case WholeTensorBarrierReason::FunctionReturnWithoutBlueprint: return "function return or external use without spat.blueprint assembly"; - case WholeTensorBarrierReason::DenseLogicalConsumer: - return "consumer requires a dense logical tensor"; + case WholeTensorBarrierReason::DenseLogicalConsumer: return "consumer requires a dense logical tensor"; } llvm_unreachable("unknown whole-tensor barrier reason"); } @@ -7427,12 +7512,12 @@ FailureOr classifyRunOutputDemand(MaterializerState& state, ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); if (resultIndex >= firstOriginalOutputs.size() || resultIndex >= fragmentTypes.size()) return targetClass.op->emitError("compact batch demand classification found an invalid output index") - << " resultIndex=" << resultIndex; + << " resultIndex=" << resultIndex; auto fragmentType = dyn_cast(fragmentTypes[resultIndex]); if (!fragmentType || !fragmentType.hasStaticShape() || fragmentType.getRank() == 0) return targetClass.op->emitError("compact batch demand classification requires static ranked fragment metadata") - << " resultIndex=" << resultIndex << " fragmentType=" << fragmentTypes[resultIndex]; + << " resultIndex=" << resultIndex << " fragmentType=" << fragmentTypes[resultIndex]; RunOutputDemand demand; demand.resultIndex = resultIndex; @@ -7440,46 +7525,43 @@ FailureOr classifyRunOutputDemand(MaterializerState& state, demand.fragmentType = fragmentType; for (ClassId destinationClass : destinationClasses) - demand.actions.push_back(TensorDemandAction { - .kind = TensorDemandActionKind::DestinationFanout, - .destinationClass = destinationClass, - .barrierReason = std::nullopt}); + demand.actions.push_back(TensorDemandAction {.kind = TensorDemandActionKind::DestinationFanout, + .destinationClass = destinationClass, + .barrierReason = std::nullopt}); if (hasMaterializationRunResultSameClassConsumer(state, targetClass.id, run, resultIndex)) - demand.actions.push_back(TensorDemandAction { - .kind = TensorDemandActionKind::SameClassIndexedFragment, - .destinationClass = std::nullopt, - .barrierReason = std::nullopt}); + demand.actions.push_back(TensorDemandAction {.kind = TensorDemandActionKind::SameClassIndexedFragment, + .destinationClass = std::nullopt, + .barrierReason = std::nullopt}); if (!hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) return demand; Value originalOutput = demand.originalOutput; if (!isTerminalHostBatchOutput(originalOutput, state.oldComputeOps)) { - demand.actions.push_back(TensorDemandAction { - .kind = TensorDemandActionKind::WholeTensorBarrier, - .destinationClass = std::nullopt, - .barrierReason = WholeTensorBarrierReason::FunctionReturnWithoutBlueprint}); + demand.actions.push_back( + TensorDemandAction {.kind = TensorDemandActionKind::WholeTensorBarrier, + .destinationClass = std::nullopt, + .barrierReason = WholeTensorBarrierReason::FunctionReturnWithoutBlueprint}); return demand; } auto outputType = dyn_cast(originalOutput.getType()); if (!outputType || !outputType.hasStaticShape()) return targetClass.op->emitError("failed to classify compact batch output demand") - << " reason=terminal blueprint publication requires static ranked output metadata" - << " producerOp='" << sourceBatch->getName() << "' resultIndex=" << resultIndex - << " sourceClass=" << targetClass.id << " valueType=" << originalOutput.getType(); + << " reason=terminal blueprint publication requires static ranked output metadata" + << " producerOp='" << sourceBatch->getName() << "' resultIndex=" << resultIndex + << " sourceClass=" << targetClass.id << " valueType=" << originalOutput.getType(); if (failed(getBatchResultProjectionInsert(sourceBatch, resultIndex))) return targetClass.op->emitError("failed to classify compact batch output demand") - << " reason=terminal blueprint publication is missing projection metadata" - << " producerOp='" << sourceBatch->getName() << "' resultIndex=" << resultIndex - << " sourceClass=" << targetClass.id << " valueType=" << originalOutput.getType(); + << " reason=terminal blueprint publication is missing projection metadata" + << " producerOp='" << sourceBatch->getName() << "' resultIndex=" << resultIndex + << " sourceClass=" << targetClass.id << " valueType=" << originalOutput.getType(); - demand.actions.push_back(TensorDemandAction { - .kind = TensorDemandActionKind::TerminalBlueprintPublication, - .destinationClass = std::nullopt, - .barrierReason = std::nullopt}); + demand.actions.push_back(TensorDemandAction {.kind = TensorDemandActionKind::TerminalBlueprintPublication, + .destinationClass = std::nullopt, + .barrierReason = std::nullopt}); return demand; } @@ -7489,11 +7571,10 @@ bool hasWholeTensorBarrier(const RunOutputDemand& demand) { }); } -FailureOr> -tryBuildCompactRunPlan(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - ArrayRef groups) { +FailureOr> tryBuildCompactRunPlan(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + ArrayRef groups) { if (run.size() < 2 || run.front().peers.empty()) return std::optional {}; @@ -7570,8 +7651,16 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, auto rankedFragmentType = cast(fragmentType); Value representativeOriginalOutput = firstOriginalOutputs[resultIndex]; - FailureOr recordedProjectedHostFragments = recordProjectedScalarHostFragmentsFromPackedRun( - state, targetClass, sourceBatch, resultIndex, run, packed, rankedFragmentType, representativeOriginalOutput, loc); + FailureOr recordedProjectedHostFragments = + recordProjectedScalarHostFragmentsFromPackedRun(state, + targetClass, + sourceBatch, + resultIndex, + run, + packed, + rankedFragmentType, + representativeOriginalOutput, + loc); if (failed(recordedProjectedHostFragments)) return failure(); @@ -7618,7 +7707,7 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, } bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { - SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId}; + SameClassConsumerLookupKey lookupKey {producerKey.instance.op, producerKey.resultIndex, classId}; auto it = state.sameClassConsumerIndex.find(lookupKey); if (it == state.sameClassConsumerIndex.end()) return false; @@ -7862,7 +7951,7 @@ FailureOr> materializeCompactBatchOutputGroupLoop(Material FailureOr packedType = getPackedRunTensorType(output.fragmentType, run.size()); if (failed(packedType)) return sourceBatch.emitOpError("cannot compact batch run for non-static ranked output") - << " resultIndex=" << output.resultIndex; + << " resultIndex=" << output.resultIndex; initValues.push_back( tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); } @@ -7901,8 +7990,7 @@ FailureOr> materializeCompactBatchOutputGroupLoop(Material auto fragmentType = dyn_cast(output.getType()); if (!fragmentType || !fragmentType.hasStaticShape()) return failure(); - Value firstOffset = scaleIndexByDim0Size( - state, targetClass.op, slotIndex, fragmentType.getDimSize(0), loc); + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, slotIndex, fragmentType.getDimSize(0), loc); yielded.push_back(createDim0InsertSlice(state, loc, output, iterArgs[outputIndex], firstOffset)); } return success(); @@ -7966,8 +8054,8 @@ LogicalResult emitPackedBatchRunSends(MaterializerState& state, auto packedIt = packedOutputIndexByResult.find(sendPlan.resultIndex); if (packedIt == packedOutputIndexByResult.end()) return targetClass.op->emitError("missing packed output for compact batch run send plan"); - if (failed(appendBatchRunReceives( - state, targetClass, run, sendPlan, plan.outputs[packedIt->second].fragmentType, loc))) + if (failed( + appendBatchRunReceives(state, targetClass, run, sendPlan, plan.outputs[packedIt->second].fragmentType, loc))) return failure(); } @@ -7987,7 +8075,8 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, auto sourceBatch = cast(run.front().peers.front().op); Location loc = sourceBatch.getLoc(); - FailureOr> packedOutputs = materializeCompactBatchOutputGroupLoop(state, targetClass, run, plan); + FailureOr> packedOutputs = + materializeCompactBatchOutputGroupLoop(state, targetClass, run, plan); if (failed(packedOutputs)) return failure(); @@ -7996,8 +8085,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, for (const TensorDemandAction& action : output.actions) { switch (action.kind) { - case TensorDemandActionKind::DestinationFanout: - break; + case TensorDemandActionKind::DestinationFanout: break; case TensorDemandActionKind::SameClassIndexedFragment: if (failed(recordLocalIndexedBatchRunValue( state, targetClass, run, output.resultIndex, (*packedOutputs)[packedIndex], output.fragmentType))) @@ -8009,14 +8097,15 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, if (failed(recordedProjectedHostFragments)) return failure(); if (!*recordedProjectedHostFragments) - return sourceBatch.emitOpError("compact batch blueprint publication requires explicit fragment assembly metadata") - << " resultIndex=" << output.resultIndex; + return sourceBatch.emitOpError( + "compact batch blueprint publication requires explicit fragment assembly metadata") + << " resultIndex=" << output.resultIndex; break; } case TensorDemandActionKind::WholeTensorBarrier: return sourceBatch.emitOpError("compact batch materialization reached a whole-tensor barrier unexpectedly") - << " resultIndex=" << output.resultIndex << " reason=" - << describeWholeTensorBarrierReason(*action.barrierReason); + << " resultIndex=" << output.resultIndex + << " reason=" << describeWholeTensorBarrierReason(*action.barrierReason); } } } @@ -8027,8 +8116,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, return success(); } -LogicalResult materializeInstanceSlot(MaterializerState& state, - const ComputeInstance& instance) { +LogicalResult materializeInstanceSlot(MaterializerState& state, const ComputeInstance& instance) { auto cpuIt = state.schedule.computeToCpuMap.find(instance); if (cpuIt == state.schedule.computeToCpuMap.end()) return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); @@ -8049,7 +8137,8 @@ LogicalResult materializeInstanceSlot(MaterializerState& state, return success(); if (isa(instance.op)) { - FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); + FailureOr run = + collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); if (succeeded(run)) { if (!targetClass.isBatch) @@ -8160,7 +8249,7 @@ bool valueMayEvaluateToCore(Value value, int64_t coreId) { return false; for (int64_t iteration = *lower; iteration < *upper; iteration += *step) { - FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef{iteration}); + FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef {iteration}); if (succeeded(evaluated) && *evaluated == coreId) return true; }