diff --git a/.serena/project.yml b/.serena/project.yml new file mode 100644 index 0000000..06c4bd3 --- /dev/null +++ b/.serena/project.yml @@ -0,0 +1,134 @@ +# the name by which the project can be referenced within Serena +project_name: raptor + +# list of languages for which language servers are started; choose from: +# al angular ansible bash clojure +# cpp cpp_ccls crystal csharp csharp_omnisharp +# dart elixir elm erlang fortran +# fsharp go groovy haskell haxe +# hlsl html java json julia +# kotlin lean4 lua luau markdown +# matlab msl nix ocaml pascal +# perl php php_phpactor powershell python +# python_jedi python_ty r rego ruby +# ruby_solargraph rust scala scss solidity +# svelte swift systemverilog terraform toml +# typescript typescript_vts vue yaml zig +# (This list may be outdated. For the current list, see values of Language enum here: +# https://github.com/oraios/serena/blob/main/src/solidlsp/ls_config.py +# For some languages, there are alternative language servers, e.g. csharp_omnisharp, ruby_solargraph.) +# Note: +# - For C, use cpp +# - For JavaScript, use typescript +# - For Angular projects, use angular (subsumes typescript+html; requires `npm install` in the project root) +# - For Svelte projects, use svelte (subsumes typescript/javascript for .svelte projects; requires npm) +# - For SCSS / Sass / plain CSS, use scss (some-sass-language-server handles all three) +# - For Free Pascal/Lazarus, use pascal +# Special requirements: +# Some languages require additional setup/installations. +# See here for details: https://oraios.github.io/serena/01-about/020_programming-languages.html#language-servers +# When using multiple languages, the first language server that supports a given file will be used for that file. +# The first language is the default language and the respective language server will be used as a fallback. +# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored. 
+languages: +- cpp +- rust +- python + +# the encoding used by text files in the project +# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings +encoding: utf-8 + +# list of additional paths to ignore in this project. +# Same syntax as gitignore, so you can use * and **. +# Note: global ignored_paths from serena_config.yml are also applied additively. +ignored_paths: + +# list of mode names that are to be activated by default, overriding the setting in the global configuration. +# The full set of modes to be activated is base_modes (from global config) + default_modes + added_modes. +# If the setting is undefined/empty, the default_modes from the global configuration (serena_config.yml) apply. +# Otherwise, this overrides the setting from the global configuration (serena_config.yml). +# Therefore, you can set this to [] if you do not want the default modes defined in the global config to apply +# for this project. +# This setting can, in turn, be overridden by CLI parameters (--mode). +# See https://oraios.github.io/serena/02-usage/050_configuration.html#modes +default_modes: + +# list of mode names to be activated additionally for this project, e.g. ["query-projects"] +# The full set of modes to be activated is base_modes (from global config) + default_modes + added_modes. +# See https://oraios.github.io/serena/02-usage/050_configuration.html#modes +added_modes: + +# list of tool names to exclude. +# This extends the existing exclusions (e.g. from the global configuration) +# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html +excluded_tools: [] + +# list of tools to include that would otherwise be disabled (particularly optional tools that are disabled by default). +# This extends the existing inclusions (e.g. from the global configuration). 
+# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html +included_optional_tools: [] + +# fixed set of tools to use as the base tool set (if non-empty), replacing Serena's default set of tools. +# This cannot be combined with non-empty excluded_tools or included_optional_tools. +# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html +fixed_tools: [] + +# time budget (seconds) per tool call for the retrieval of additional symbol information +# such as docstrings or parameter information. +# This overrides the corresponding setting in the global configuration; see the documentation there. +# If null or missing, use the setting from the global configuration. +symbol_info_budget: + +# The language backend to use for this project. +# If not set, the global setting from serena_config.yml is used. +# Valid values: LSP, JetBrains +# Note: the backend is fixed at startup. If a project with a different backend +# is activated post-init, an error will be returned. +language_backend: + +# line ending convention to use when writing source files. +# Possible values: unset (use global setting), "lf", "crlf", or "native" (platform default) +# This does not affect Serena's own files (e.g. memories and configuration files), which always use native line endings. +line_ending: + +# list of regex patterns which, when matched, mark a memory entry as read‑only. +# Extends the list from the global configuration, merging the two lists. +read_only_memory_patterns: [] + +# list of regex patterns for memories to completely ignore. +# Matching memories will not appear in list_memories or activate_project output +# and cannot be accessed via read_memory or write_memory. +# To access ignored memory files, use the read_file tool on the raw file path. +# Extends the list from the global configuration, merging the two lists. 
+# Example: ["_archive/.*", "_episodes/.*"] +ignored_memory_patterns: [] + +# advanced configuration option allowing to configure language server-specific options. +# Maps the language key to the options. +# Have a look at the docstring of the constructors of the LS implementations within solidlsp (e.g., for C# or PHP) to see which options are available. +# No documentation on options means no options are available. +ls_specific_settings: {} + +# list of additional workspace folder paths for cross-package reference support (e.g. in monorepos). +# Paths can be absolute or relative to the project root. +# Each folder is registered as an LSP workspace folder, enabling language servers to discover +# symbols and references across package boundaries. +# Currently supported for: TypeScript. +# Example: +# additional_workspace_folders: +# - ../sibling-package +# - ../shared-lib +additional_workspace_folders: [] + +# whether the project is in read-only mode +# If set to true, all editing tools will be disabled and attempts to use them will result in an error +# Added on 2025-04-18 +read_only: false + +# whether to use project's .gitignore files to ignore files +ignore_all_files_in_gitignore: true + +# initial prompt for the project. It will always be given to the LLM upon activating the project +# (contrary to the memories, which are loaded on demand). 
+initial_prompt: '' diff --git a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp index d4fd626..9c300dc 100644 --- a/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/SpatialLayoutPlanningPass.cpp @@ -97,11 +97,17 @@ static spatial::SpatReconciliatorOp insertRowStripReconciliator(IRRewriter& rewr value.getLoc(), outputType, value, + ValueRange {}, rewriter.getStringAttr(kLogicalLayout), rewriter.getStringAttr(kRowStripLayout), rewriter.getDenseI64ArrayAttr(offsets), rewriter.getDenseI64ArrayAttr(sizes), - rewriter.getStringAttr(kRowStripIndexMap)); + rewriter.getStringAttr(kRowStripIndexMap), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr); } static void materializeDenseUses(IRRewriter& rewriter, diff --git a/src/PIM/Dialect/Spatial/Spatial.td b/src/PIM/Dialect/Spatial/Spatial.td index 6e1d701..1b7e60d 100644 --- a/src/PIM/Dialect/Spatial/Spatial.td +++ b/src/PIM/Dialect/Spatial/Spatial.td @@ -233,15 +233,21 @@ def SpatReluPlanOp : SpatOp<"relu_plan", []> { } def SpatReconciliatorOp : SpatOp<"reconciliator", []> { - let summary = "Passive logical-to-physical layout selection record"; + let summary = "Logical-to-physical layout record or explicit fragment assembly"; let arguments = (ins SpatTensor:$input, + Variadic:$fragments, StrAttr:$logicalLayout, StrAttr:$physicalLayout, DenseI64ArrayAttr:$fragmentOffsets, DenseI64ArrayAttr:$fragmentSizes, - StrAttr:$indexMap + StrAttr:$indexMap, + OptionalAttr:$mode, + OptionalAttr:$fragmentOperandIndices, + OptionalAttr:$fragmentStrides, + OptionalAttr:$conflictPolicy, + OptionalAttr:$coveragePolicy ); let results = (outs diff --git a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp index f4420c2..7ea4dde 100644 --- a/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOpsVerify.cpp @@ -383,7 +383,7 @@ LogicalResult 
SpatConcatOp::verify() { static bool isKnownLogicalLayout(StringRef layout) { return layout == "nchw"; } static bool isKnownPhysicalLayout(StringRef layout) { - return layout == "dense_nchw" || layout == "nchw_row_strip"; + return layout == "dense_nchw" || layout == "nchw_row_strip" || layout == "fragmented"; } static LogicalResult verifyPlanTensorTypes(Operation* op, Value input, Value output, StringRef kind) { @@ -437,7 +437,9 @@ LogicalResult SpatReluPlanOp::verify() { } LogicalResult SpatReconciliatorOp::verify() { - if (failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator"))) + auto modeAttr = getModeAttr(); + bool isFragmentAssembly = modeAttr && modeAttr.getValue() == "fragment_assembly"; + if (!isFragmentAssembly && failed(verifyPlanTensorTypes(getOperation(), getInput(), getOutput(), "spat.reconciliator"))) return failure(); if (!isKnownLogicalLayout(getLogicalLayout())) return emitError("requires a known logical layout"); @@ -452,23 +454,154 @@ LogicalResult SpatReconciliatorOp::verify() { auto sizes = getFragmentSizes(); if (offsets.size() != sizes.size()) return emitError("fragment offset and size arrays must have the same length"); + int64_t rank = logicalType.getRank(); if (offsets.empty()) return success(); - - int64_t rank = logicalType.getRank(); if (rank <= 0 || offsets.size() % rank != 0) return emitError("fragment metadata must be a whole number of rank-sized fragments"); - ArrayRef shape = logicalType.getShape(); - for (int64_t index = 0; index < static_cast(offsets.size()); ++index) { - int64_t dim = index % rank; - int64_t offset = offsets[index]; - int64_t size = sizes[index]; - if (offset < 0 || size < 0) - return emitError("fragment offsets and sizes must be non-negative"); - int64_t logicalDim = shape[dim]; - if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim) - return emitError("fragment bounds must stay within the logical tensor shape"); + auto verifyBoundsOnly = [&](ArrayRef 
strideValues) -> LogicalResult { + ArrayRef shape = logicalType.getShape(); + for (int64_t index = 0; index < static_cast(offsets.size()); ++index) { + int64_t dim = index % rank; + int64_t offset = offsets[index]; + int64_t size = sizes[index]; + int64_t stride = strideValues.empty() ? 1 : strideValues[index]; + if (offset < 0 || size < 0 || stride < 0) + return emitError("fragment offsets, sizes, and strides must be non-negative"); + int64_t logicalDim = shape[dim]; + if (!ShapedType::isDynamic(logicalDim) && offset + size > logicalDim) + return emitError("fragment bounds must stay within the logical tensor shape"); + if (stride != 1) + return emitError("fragment assembly currently requires unit strides"); + } + return success(); + }; + + if (!isFragmentAssembly) { + if (failed(verifyBoundsOnly({}))) + return failure(); + if (!getFragments().empty()) + return emitError("legacy reconciliator does not accept extra fragment operands"); + if (getFragmentStridesAttr() || getConflictPolicyAttr() || getCoveragePolicyAttr()) + return emitError("legacy reconciliator does not accept fragment assembly attributes"); + return success(); + } + + auto stridesAttr = getFragmentStridesAttr(); + auto operandIndicesAttr = getFragmentOperandIndicesAttr(); + if (!operandIndicesAttr) + return emitError("fragment assembly reconciliator requires fragment operand indices"); + if (!stridesAttr) + return emitError("fragment assembly reconciliator requires fragment strides"); + ArrayRef operandIndices = operandIndicesAttr.asArrayRef(); + ArrayRef strides = stridesAttr.asArrayRef(); + if (strides.size() != offsets.size()) + return emitError("fragment stride and offset arrays must have the same length"); + if (!getConflictPolicyAttr() || !getCoveragePolicyAttr()) + return emitError("fragment assembly reconciliator requires conflict and coverage policies"); + if (getConflictPolicy() != "disjoint") + return emitError("fragment assembly reconciliator currently supports only 
conflict_policy=\"disjoint\""); + if (getCoveragePolicy() != "complete" && getCoveragePolicy() != "partial") + return emitError("fragment assembly reconciliator coverage_policy must be \"complete\" or \"partial\""); + + SmallVector operands; + operands.push_back(getInput()); + llvm::append_range(operands, getFragments()); + int64_t operandCount = static_cast(operands.size()); + int64_t fragmentCount = static_cast(operandIndices.size()); + if (operandCount == 0) + return emitError("fragment assembly reconciliator requires at least one operand"); + if (static_cast(offsets.size()) != fragmentCount * rank) + return emitError("fragment assembly metadata count must match operand count * result rank"); + if (failed(verifyBoundsOnly(strides))) + return failure(); + + SmallVector, SmallVector>, 8> slices; + slices.reserve(static_cast(fragmentCount)); + SmallVector, 4>, 8> sizesByOperand(static_cast(operandCount)); + for (int64_t fragmentIndex = 0; fragmentIndex < fragmentCount; ++fragmentIndex) { + int64_t operandIndex = operandIndices[fragmentIndex]; + if (operandIndex < 0 || operandIndex >= operandCount) + return emitError("fragment assembly operand index is out of range"); + + auto operandType = dyn_cast(operands[operandIndex].getType()); + if (!operandType || !operandType.hasStaticShape()) + return emitError("fragment assembly reconciliator requires static ranked tensor operands"); + if (operandType.getRank() != rank) + return emitError("fragment assembly reconciliator requires operand/result rank match"); + + SmallVector fragmentOffsets; + SmallVector fragmentSizes; + fragmentOffsets.reserve(rank); + fragmentSizes.reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + fragmentOffsets.push_back(offsets[flatIndex]); + fragmentSizes.push_back(sizes[flatIndex]); + } + + sizesByOperand[static_cast(operandIndex)].push_back(fragmentSizes); + + for (const auto& [existingOffsets, existingSizes] : slices) { + bool overlaps 
= true; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t begin = fragmentOffsets[dim]; + int64_t end = begin + fragmentSizes[dim]; + int64_t existingBegin = existingOffsets[dim]; + int64_t existingEnd = existingBegin + existingSizes[dim]; + if (end <= existingBegin || existingEnd <= begin) { + overlaps = false; + break; + } + } + if (overlaps) + return emitError("fragment assembly reconciliator requires disjoint static slices"); + } + slices.push_back({std::move(fragmentOffsets), std::move(fragmentSizes)}); + } + + for (int64_t operandIndex = 0; operandIndex < operandCount; ++operandIndex) { + if (sizesByOperand[static_cast(operandIndex)].empty()) + return emitError("fragment assembly reconciliator requires every operand to contribute at least one fragment"); + + auto operandType = cast(operands[operandIndex].getType()); + ArrayRef operandShape = operandType.getShape(); + auto& fragmentShapes = sizesByOperand[static_cast(operandIndex)]; + if (fragmentShapes.size() == 1) { + if (!llvm::equal(operandShape, fragmentShapes.front())) + return emitError("single-fragment reconciliator operand shape must match declared fragment size"); + continue; + } + + ArrayRef fragmentShape = fragmentShapes.front(); + for (ArrayRef otherShape : fragmentShapes) + if (!llvm::equal(fragmentShape, otherShape)) + return emitError("packed reconciliator operand requires equal fragment sizes per operand"); + if (llvm::equal(operandShape, fragmentShape)) + continue; + if (!llvm::equal(operandShape.drop_front(), fragmentShape.drop_front())) + return emitError("packed reconciliator operand must match fragment shape on non-packed dimensions"); + if (operandShape.front() != static_cast(fragmentShapes.size()) * fragmentShape.front()) + return emitError("packed reconciliator operand first dimension must equal fragment_count * fragment_size"); + } + + if (getCoveragePolicy() == "complete") { + int64_t covered = 0; + int64_t logicalElements = 1; + for (int64_t dimSize : logicalType.getShape()) { 
+ if (ShapedType::isDynamic(dimSize)) + return emitError("fragment assembly complete coverage requires static result shape"); + logicalElements *= dimSize; + } + for (const auto& [ignoredOffsets, fragmentSizes] : slices) { + int64_t fragmentElements = 1; + for (int64_t dimSize : fragmentSizes) + fragmentElements *= dimSize; + covered += fragmentElements; + } + if (covered != logicalElements) + return emitError("fragment assembly complete coverage must cover the whole result exactly"); } return success(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 82dbf03..21089bb 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -17,6 +17,7 @@ #include "llvm/Support/raw_ostream.h" #include +#include #include #include #include @@ -24,7 +25,6 @@ #include "MaterializeMergeSchedule.hpp" #include "Scheduling/ComputeInstanceUtils.hpp" -#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" @@ -303,6 +303,20 @@ struct ProjectedExtractReplacement { ProjectedFragmentLayout layout; }; +struct PendingProjectedHostOutputFragment { + Value originalOutput; + ClassId sourceClass = 0; + Value operand; + RankedTensorType operandType; + RankedTensorType fragmentType; + int64_t packedFragmentIndex = -1; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + uint32_t sourceLane = 0; + Location loc; +}; + struct CloneIndexingContext { std::optional runSlotIndex; std::optional projectionSlotIndex; @@ -325,37 +339,12 @@ struct AffineProjectedInputSliceMatch { }; struct MaterializerState; - -struct PendingProjectedHostReceiveGroup { - 
Value originalOutput; - ClassId ownerClassId = 0; - RankedTensorType fragmentType; - SmallVector keys; - MessageVector messages; - Location loc; -}; - -struct PendingScalarReceiveRecord { - PendingScalarReceiveRecord(ArrayRef keys, - ClassId targetClassId, - Type receiveType, - const MessageVector& messages, - Location loc) - : targetClassId(targetClassId), - receiveType(receiveType), - messages(messages), - loc(loc) { - this->keys.append(keys.begin(), keys.end()); - } - - SmallVector keys; - ClassId targetClassId = 0; - Type receiveType; - MessageVector messages; - Location loc; - bool materialized = false; - Value value; -}; +FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value packed, + Value originalOutput, + Location loc); FailureOr materializeProjectedExtractReplacement(MaterializerState& state, MaterializedClass& targetClass, @@ -387,13 +376,14 @@ LogicalResult localizeCapturesInClonedOp(MaterializerState& state, MaterializedClass& targetClass, Operation& clonedOp, IRMapping* mapper = nullptr); -bool requiresConstantProjectionSlotIndex(MaterializerState& state, - MaterializedClass& targetClass, - Operation* sourceOp); +LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass); bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, const AffineProjectedInputSliceMatch& match, ProducerKey producer, uint32_t consumerLane); +std::optional getProjectedInputSliceMatch(MaterializerState& state, + SpatComputeBatch batch, + unsigned inputIndex); class AvailableValueStore { public: @@ -488,11 +478,7 @@ struct MaterializerState { AvailableValueStore availableValues; DenseMap hostReplacements; DenseMap hostOutputOwners; - SmallVector pendingProjectedHostReceives; - SmallVector pendingScalarReceives; - DenseMap, ProducerKeyInfo> pendingScalarReceiveLookup; - DenseMap firstLateCommunicationOps; - 
int64_t nextCommunicationTraceId = 0; + SmallVector pendingProjectedHostOutputFragments; DenseSet oldComputeOps; MaterializerState(func::FuncOp func, @@ -608,32 +594,6 @@ std::optional getContiguousProducerRangeForKeys(ArrayRef getPhysicallyContiguousProducerRangeForKeys(ArrayRef keys) { - if (keys.empty()) - return std::nullopt; - - ProducerKey first = keys.front(); - auto batch = dyn_cast_or_null(first.instance.op); - if (!batch || first.instance.laneCount == 0) - return std::nullopt; - - uint32_t laneStart = first.instance.laneStart; - uint32_t nextLane = laneStart; - for (ProducerKey key : keys) { - if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) - return std::nullopt; - if (key.instance.laneStart != nextLane) - return std::nullopt; - nextLane += key.instance.laneCount; - } - - uint32_t laneCount = nextLane - laneStart; - if (laneStart + laneCount > static_cast(batch.getLaneCount())) - return std::nullopt; - - return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); -} - WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) { return {sourceOp, resultIndex, classId}; } @@ -689,6 +649,11 @@ collectProducerKeysForDestinations(Value value, std::optional l return keys; } + if (logicalConsumer && isa(logicalConsumer->op)) { + keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); + return keys; + } + return {}; } @@ -715,6 +680,11 @@ collectProducerKeysForDestinations(Value value, std::optional l return {}; if (batch.getNumResults() != 0) { + if (logicalConsumer && isa(logicalConsumer->op)) { + keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); + return keys; + } + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) keys.push_back(getBatchLaneProducerKey(batch, lane, 1, 
result.getResultNumber())); return keys; @@ -747,6 +717,9 @@ std::optional getInputRequestProducerKey(Value value, if (std::optional lane = getConstantFirstSliceOffset(extract)) return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); + if (logicalConsumer && isa(logicalConsumer->op)) + return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); + return std::nullopt; } @@ -771,8 +744,11 @@ std::optional getInputRequestProducerKey(Value value, if (!result) return std::nullopt; - if (batch.getNumResults() != 0) + if (batch.getNumResults() != 0) { + if (logicalConsumer && isa(logicalConsumer->op)) + return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); return getWholeBatchProducerKey(batch, result.getResultNumber()); + } return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; } @@ -780,6 +756,60 @@ std::optional getInputRequestProducerKey(Value value, return std::nullopt; } +std::optional getWholeBatchProducerKeyForDirectBatchResult(Value value) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + + auto batch = dyn_cast_or_null(result.getOwner()); + if (!batch || batch.getNumResults() == 0) + return std::nullopt; + + return getWholeBatchProducerKey(batch, result.getResultNumber()); +} + +bool canUseProjectedLaneInput(MaterializerState& state, + SpatComputeBatch consumerBatch, + unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer) { + auto producerResult = dyn_cast(input); + if (!producerResult) + return false; + + auto producerBatch = dyn_cast_or_null(producerResult.getOwner()); + if (!producerBatch || producerBatch.getNumResults() == 0) + return false; + + std::optional match = + getProjectedInputSliceMatch(state, consumerBatch, inputIndex); + if (!match) + return false; + + ProducerKey laneProducer = + getBatchLaneProducerKey(producerBatch, logicalConsumer.laneStart, 1, producerResult.getResultNumber()); + 
return isProjectedInputSliceCompatibleWithProducerFragments( + consumerBatch, *match, laneProducer, logicalConsumer.laneStart); +} + +SmallVector collectProducerKeysForBatchInputDestinations(MaterializerState& state, + SpatComputeBatch consumerBatch, + unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer) { + if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input)) { + if (!canUseProjectedLaneInput(state, consumerBatch, inputIndex, input, logicalConsumer)) { + auto producerBatch = cast(wholeBatchProducer->instance.op); + SmallVector keys; + for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) + keys.push_back(getBatchLaneProducerKey(producerBatch, lane, 1, wholeBatchProducer->resultIndex)); + return keys; + } + } + + return collectProducerKeysForDestinations(input, logicalConsumer); +} + class CpuUnionFind { public: void insert(CpuId cpu) { parent.try_emplace(cpu, cpu); } @@ -969,7 +999,6 @@ LogicalResult collectHostOutputs(MaterializerState& state) { DenseSet seenOutputs; SmallVector orderedOutputs; DenseMap preferredOwners; - for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { auto cpuIt = state.schedule.computeToCpuMap.find(instance); if (cpuIt == state.schedule.computeToCpuMap.end()) @@ -994,17 +1023,17 @@ LogicalResult collectHostOutputs(MaterializerState& state) { ClassId currentOwner = preferredOwners.lookup(output); bool terminalHost = isTerminalHostBatchOutput(output, state.oldComputeOps); if (terminalHost) { - // Terminal resultful batch outputs are still published through scalar - // host-output slots unless the materialized batch class owns the output - // directly. Selecting an arbitrary batch class as the host owner would - // require a projection-aware batch publication path, which the - // materializer does not currently implement. 
- if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + // Terminal batch outputs should stay owned by the producing batch class + // so publication remains explicit in IR via the batch host-output path. + if (!state.classes[currentOwner].isBatch && materializedClass.isBatch) preferredOwners[output] = classId; continue; } - if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + // Keep batch-defined outputs on a batch owner whenever one exists so + // publication and reconstruction remain explicit in Spatial IR instead of + // falling back to late scalar host forwarding. + if (!state.classes[currentOwner].isBatch && materializedClass.isBatch) preferredOwners[output] = classId; } } @@ -1081,7 +1110,8 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) { state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); if (failed(batchLaneCountAttr)) return failure(); - auto batch = SpatScheduledComputeBatch::create(state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); + auto batch = SpatScheduledComputeBatch::create( + state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); batch.getProperties().setOperandSegmentSizes({0, 0}); auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id"); if (failed(coreIds)) @@ -1150,6 +1180,10 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali llvm_unreachable("Cannot reach here"); } +// ----------------------------------------------------------------------------- +// Materialized-class value localization helpers. 
+// ----------------------------------------------------------------------------- + Region* getParentRegion(Value value) { if (auto blockArg = dyn_cast(value)) return blockArg.getOwner()->getParent(); @@ -1979,11 +2013,7 @@ std::optional extractPackedProducerSlice(MaterializerState& state, state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); - FailureOr slice = - createDim0ExtractSliceInClass(state, materializedClass, materializedClass.op->getLoc(), packed, firstOffset, rowCount); - if (failed(slice)) - return std::nullopt; - return *slice; + return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); } std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { @@ -1998,15 +2028,15 @@ std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId c return valueIt->second; } -FailureOr getPackedSliceForRunIndex(MaterializerState& state, - MaterializedClass& targetClass, - Value packed, - RankedTensorType fragmentType, - size_t index, - Location loc) { +Value getPackedSliceForRunIndex(MaterializerState& state, + Operation* anchor, + Value packed, + RankedTensorType fragmentType, + size_t index, + Location loc) { int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); - Value firstOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); - return createDim0ExtractSliceInClass(state, targetClass, loc, packed, firstOffset, fragmentType.getDimSize(0)); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, anchor, rowOffset); + return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); } FailureOr createReceiveConcatLoop(MaterializerState& state, @@ -2024,8 +2054,6 @@ FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& PackedScalarRunValue& run, Location loc); 
-SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run); - bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { return run.kind == PackedScalarRunKind::DeferredLocalCompute; } @@ -2088,45 +2116,12 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) continue; - size_t flattenedIndexBase = 0; for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { - std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); - if (contiguousKey && containsProducerKey(*contiguousKey, key)) { - FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); - if (failed(slotPackedType)) - return std::nullopt; - - MaterializedClass& materializedClass = state.classes[classId]; - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); - - FailureOr packed = - materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); - if (failed(packed)) - return std::nullopt; - FailureOr slotPacked = - getPackedSliceForRunIndex(state, materializedClass, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); - if (failed(slotPacked)) - return std::nullopt; - - if (*contiguousKey == key) { - record(key, classId, *slotPacked); - return *slotPacked; - } - - std::optional sliced = - extractPackedProducerSlice(state, materializedClass, *contiguousKey, *slotPacked, key); - if (!sliced) - return std::nullopt; - - record(key, classId, *sliced); - return *sliced; - } - + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); auto keyIt = llvm::find(slot.keys, key); - if (keyIt == slot.keys.end()) { - flattenedIndexBase += slot.keys.size(); + if ((!contiguousKey && keyIt == slot.keys.end()) + || (contiguousKey && !containsProducerKey(*contiguousKey, key))) continue; - } MaterializedClass& materializedClass = 
state.classes[classId]; state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); @@ -2135,11 +2130,35 @@ std::optional AvailableValueStore::lookupPackedRun(MaterializerState& sta materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); if (failed(packed)) return std::nullopt; - size_t flattenedIndex = flattenedIndexBase + static_cast(std::distance(slot.keys.begin(), keyIt)); - FailureOr sliced = - getPackedSliceForRunIndex(state, materializedClass, *packed, run.fragmentType, flattenedIndex, (*packed).getLoc()); - if (failed(sliced)) + + if (!contiguousKey) { + Value sliced = getPackedSliceForRunIndex(state, + materializedClass.op, + *packed, + run.fragmentType, + static_cast(std::distance(slot.keys.begin(), keyIt)), + (*packed).getLoc()); + record(key, classId, sliced); + return sliced; + } + + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); + if (failed(slotPackedType)) return std::nullopt; + + Value slotPacked = + getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); + + if (*contiguousKey == key) { + record(key, classId, slotPacked); + return slotPacked; + } + + std::optional sliced = + extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key); + if (!sliced) + return std::nullopt; + record(key, classId, *sliced); return *sliced; } @@ -2179,6 +2198,7 @@ std::optional AvailableValueStore::lookup(MaterializerState& state, Produ auto valueIt = classValues.find(classId); if (valueIt == classValues.end()) continue; + std::optional slice = extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); if (!slice) @@ -2418,72 +2438,6 @@ Value createLaneIndexedIndexValue(MaterializerState& state, return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc); } -FailureOr remapProjectionIndexLike(MaterializerState& state, - Operation* 
anchor, - OpFoldResult value, - Value sourceLaneArg, - Value mappedLaneValue, - Location loc) { - if (auto attr = dyn_cast(value)) - return value; - - Value operand = cast(value); - if (operand == sourceLaneArg) - return OpFoldResult(mappedLaneValue); - - if (matchPattern(operand, m_Constant())) - return getAsOpFoldResult(operand); - - auto affineApply = operand.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return failure(); - - SmallVector remappedOperands; - remappedOperands.reserve(affineApply.getMapOperands().size()); - for (Value mapOperand : affineApply.getMapOperands()) { - FailureOr remapped = - remapProjectionIndexLike(state, anchor, OpFoldResult(mapOperand), sourceLaneArg, mappedLaneValue, loc); - if (failed(remapped)) - return failure(); - remappedOperands.push_back(getValueOrCreateConstantIndexOp(state.rewriter, loc, *remapped)); - } - - return getAsOpFoldResult( - createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func)); -} - -FailureOr createProjectionLaneValueForKeys(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Location loc) { - if (!sourceClass.isBatch) - return sourceClass.op->emitError("projection lane mapping expects a batch materialized class"); - - auto batch = cast(sourceClass.op); - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("missing lane argument for projected batch host publication"); - - if (keys.size() == 1) { - if (keys.front().instance.laneCount != 1) - return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); - return getOrCreateIndexConstant(state.constantFolder, sourceClass.op, keys.front().instance.laneStart); - } - - if (keys.size() != sourceClass.cpus.size()) - return batch.emitOpError("projected batch host publication expected one producer key per materialized batch lane"); - - SmallVector sourceLanes; - 
sourceLanes.reserve(keys.size()); - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); - sourceLanes.push_back(key.instance.laneStart); - } - - return createIndexedIndexValue(state, sourceClass.op, sourceLanes, *laneArg, loc, std::nullopt, true); -} - FailureOr> getPeerLogicalInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId logicalSlot) { SmallVector peers; @@ -2581,59 +2535,6 @@ bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComp return !hasRealComputeConsumer(output, oldComputeOps); } -bool isProjectedTerminalBatchHostOutput(Value output, const DenseSet& oldComputeOps) { - if (!isTerminalHostBatchOutput(output, oldComputeOps)) - return false; - - auto batch = dyn_cast_or_null(output.getDefiningOp()); - auto originalResult = dyn_cast(output); - if (!batch || !originalResult) - return false; - - FailureOr projection = - getBatchResultProjectionInsert(batch, originalResult.getResultNumber()); - if (failed(projection)) - return false; - - return projection->getSource().getType() != output.getType(); -} - -LogicalResult emitBatchToScalarDestinationDiagnostic(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value originalOutput) { - auto diag = sourceClass.op->emitError("resultful compute_batch output would enter batch-to-scalar class fanout"); - diag << " sourceClassId=" << sourceClass.id << " sourceKind=" << (sourceClass.isBatch ? "batch" : "scalar"); - diag << " liveExternalUse=" << (hasLiveExternalUseCached(state, originalOutput) ? "true" : "false"); - diag << " terminalHostBatch=" << (isTerminalHostBatchOutput(originalOutput, state.oldComputeOps) ? "true" : "false"); - diag << " originalDef=" - << (originalOutput.getDefiningOp() ? 
originalOutput.getDefiningOp()->getName().getStringRef() : StringRef("")); - - bool first = true; - diag << " destinationClasses=["; - auto destIt = state.producerDestClasses.find(keys.front()); - ArrayRef destinations = destIt == state.producerDestClasses.end() ? ArrayRef {} : ArrayRef(destIt->second); - for (ClassId classId : destinations) { - if (!first) - diag << ", "; - first = false; - const MaterializedClass& destClass = state.classes[classId]; - diag << classId << ":" << (destClass.isBatch ? "batch" : "scalar"); - } - diag << "]"; - - diag << " producerKeys=["; - first = true; - for (ProducerKey key : keys) { - if (!first) - diag << ", "; - first = false; - diag << key.instance.op->getName().getStringRef() << ":r" << key.resultIndex << ":laneStart=" << key.instance.laneStart - << ":laneCount=" << key.instance.laneCount; - } - diag << "]"; - return failure(); -} void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { SmallVector& destinations = state.producerDestClasses[key]; @@ -2659,8 +2560,16 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { state, [&](CpuId, ClassId targetClass, ComputeInstance scheduledConsumer, ComputeInstance logicalConsumer, SlotId) -> LogicalResult { - for (Value input : getComputeInstanceInputs(scheduledConsumer)) { - for (ProducerKey producerKey : collectProducerKeysForDestinations(input, logicalConsumer)) { + SmallVector consumerInputs = getComputeInstanceInputs(scheduledConsumer); + for (auto [inputIndex, input] : llvm::enumerate(consumerInputs)) { + SmallVector producerKeys; + if (auto batchConsumer = dyn_cast(logicalConsumer.op)) + producerKeys = collectProducerKeysForBatchInputDestinations( + state, batchConsumer, static_cast(inputIndex), input, logicalConsumer); + else + producerKeys = collectProducerKeysForDestinations(input, logicalConsumer); + + for (ProducerKey producerKey : producerKeys) { ComputeInstance scheduledProducer = 
getScheduledChunkForLogicalInstance(state, producerKey.instance); auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); if (producerCpuIt == state.schedule.computeToCpuMap.end()) @@ -3030,6 +2939,159 @@ getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, un return match; } +FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); + +FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { + if (value == laneArg) + return static_cast(lane); + + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; + + auto affineApply = value.getDefiningOp(); + if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return failure(); + + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane); + if (failed(evaluated)) + return failure(); + operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); + } + + SmallVector results; + if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) + return failure(); + + auto intAttr = dyn_cast(results.front()); + if (!intAttr) + return failure(); + return intAttr.getInt(); +} + +FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) { + if (auto attr = llvm::dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return failure(); + return intAttr.getInt(); + } + return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); +} + +FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return failure(); + + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return failure(); + + 
for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; + + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (candidateIndex == resultIndex) + return insert; + } + + return failure(); +} + +FailureOr> +evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) { + SmallVector evaluated; + evaluated.reserve(values.size()); + for (OpFoldResult value : values) { + FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane); + if (failed(index)) + return failure(); + evaluated.push_back(*index); + } + return evaluated; +} + + +bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane) { + auto producerBatch = dyn_cast_or_null(producer.instance.op); + if (!producerBatch) + return true; + + FailureOr producerProjection = + getBatchResultProjectionInsert(producerBatch, producer.resultIndex); + if (failed(producerProjection)) + return true; + + std::optional producerLaneArg = producerBatch.getLaneArgument(); + std::optional consumerLaneArg = consumerBatch.getLaneArgument(); + if (!producerLaneArg || !consumerLaneArg) + return false; + + SmallVector consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end()); + SmallVector loopIterationIndices(match.loops.size(), 0); + + const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { + SmallVector consumerOffsets; + consumerOffsets.reserve(match.offsets.size()); + for (OpFoldResult offset : match.offsets) { + FailureOr evaluated = + evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices); + if (failed(evaluated)) + return false; + consumerOffsets.push_back(*evaluated); + } + + uint32_t 
producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount; + for (uint32_t producerLane = producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) { + FailureOr> producerOffsets = + evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane); + FailureOr> producerSizes = + evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane); + FailureOr> producerStrides = + evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane); + if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides)) + return false; + if (!areAllUnitStrides(*producerStrides)) + return false; + if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes)) + return true; + } + + return false; + }; + + if (match.loops.empty()) + return consumerSliceFitsOneProducerFragment(); + + const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { + if (loopIndex == match.loops.size()) + return consumerSliceFitsOneProducerFragment(); + + for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { + loopIterationIndices[loopIndex] = iteration; + if (!self(self, loopIndex + 1)) + return false; + } + return true; + }; + + return recurse(recurse, 0); +} + + LogicalResult collectProjectedTransfers(MaterializerState& state) { struct PendingProjectedTransferDescriptor { ProjectedBatchInputKey inputKey; @@ -3359,609 +3421,10 @@ ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey ke return it->second; } -std::optional getKnownMinimumIndexValue(Value value) { - if (std::optional constant = matchConstantIndexValue(value)) - return *constant; - - if (auto blockArg = dyn_cast(value)) { - if (blockArg.getArgNumber() == 0) { - if (auto loop = dyn_cast_or_null(blockArg.getOwner()->getParentOp())) - return matchConstantIndexValue(loop.getLowerBound()); - } - 
return std::nullopt; - } - - if (auto add = value.getDefiningOp()) { - std::optional lhs = getKnownMinimumIndexValue(add.getLhs()); - std::optional rhs = getKnownMinimumIndexValue(add.getRhs()); - if (lhs && rhs) - return *lhs + *rhs; - return std::nullopt; - } - - if (auto mul = value.getDefiningOp()) { - std::optional lhs = getKnownMinimumIndexValue(mul.getLhs()); - std::optional rhs = getKnownMinimumIndexValue(mul.getRhs()); - if (!lhs || !rhs) - return std::nullopt; - if (*lhs >= 0 && *rhs >= 0) - return *lhs * *rhs; - return std::nullopt; - } - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return std::nullopt; - - SmallVector operands; - operands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - std::optional minimum = getKnownMinimumIndexValue(operand); - if (!minimum) - return std::nullopt; - operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *minimum)); - } - - SmallVector results; - if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) - return std::nullopt; - - auto intAttr = dyn_cast(results.front()); - if (!intAttr) - return std::nullopt; - return intAttr.getInt(); -} - -std::optional getKnownMinimumCommunicationChannelId(Operation* op) { - if (auto send = dyn_cast(op)) - return getKnownMinimumIndexValue(send.getChannelId()); - if (auto receive = dyn_cast(op)) - return getKnownMinimumIndexValue(receive.getChannelId()); - - std::optional minimum; - op->walk([&](Operation* nested) { - if (nested == op) - return; - std::optional nestedMinimum = getKnownMinimumCommunicationChannelId(nested); - if (!nestedMinimum) - return; - if (!minimum || *nestedMinimum < *minimum) - minimum = *nestedMinimum; - }); - return minimum; -} - -void setInsertionPointForScalarReceive(MaterializerState& state, - MaterializedClass& targetClass, - int64_t channelId) { - assert(!targetClass.isBatch 
&& "scalar receive ordering expects a scalar target class"); - - for (Operation& op : *targetClass.body) { - if (op.hasTrait()) - break; - - std::optional existingChannel = getKnownMinimumCommunicationChannelId(&op); - if (existingChannel && *existingChannel > channelId) { - state.rewriter.setInsertionPoint(&op); - return; - } - } - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); -} - // ----------------------------------------------------------------------------- // Communication materialization helpers. // ----------------------------------------------------------------------------- -constexpr const char* kRaptorMinChannelIdAttr = "raptor.min_channel_id"; -constexpr const char* kRaptorMaterializerAttr = "raptor.materializer"; -constexpr const char* kRaptorCommTraceIdAttr = "raptor.comm_trace_id"; -constexpr const char* kRaptorCommTraceKindAttr = "raptor.comm_trace_kind"; -constexpr const char* kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase"; -constexpr const char* kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id"; -constexpr const char* kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind"; -constexpr const char* kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal"; -constexpr const char* kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload"; -constexpr const char* kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages"; -constexpr const char* kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op"; -constexpr const char* kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op"; - -int64_t getMinimumChannelId(ArrayRef channelIds) { - assert(!channelIds.empty() && "expected at least one channel id"); - int64_t minChannelId = channelIds.front(); - for (int64_t channelId : channelIds.drop_front()) - if (channelId < minChannelId) - minChannelId = channelId; - return minChannelId; -} - -SmallVector getScalarSendChannelOrder(const MessageVector& messages) { - SmallVector order; - 
order.reserve(messages.size()); - for (size_t i = 0, e = messages.size(); i < e; ++i) - order.push_back(i); - - llvm::sort(order, [&](size_t lhs, size_t rhs) { - if (messages.channelIds[lhs] != messages.channelIds[rhs]) - return messages.channelIds[lhs] < messages.channelIds[rhs]; - if (messages.sourceCoreIds[lhs] != messages.sourceCoreIds[rhs]) - return messages.sourceCoreIds[lhs] < messages.sourceCoreIds[rhs]; - return messages.targetCoreIds[lhs] < messages.targetCoreIds[rhs]; - }); - return order; -} - -MessageVector reorderMessages(const MessageVector& messages, ArrayRef order) { - MessageVector reordered; - reordered.channelIds.reserve(messages.size()); - reordered.sourceCoreIds.reserve(messages.size()); - reordered.targetCoreIds.reserve(messages.size()); - for (size_t index : order) - reordered.append(messages.channelIds[index], messages.sourceCoreIds[index], messages.targetCoreIds[index]); - return reordered; -} - -MessageVector reorderScalarSendMessagesByChannel(const MessageVector& messages) { - return reorderMessages(messages, getScalarSendChannelOrder(messages)); -} - -ProjectedTransferDescriptor reorderProjectedDescriptorByMessageOrder(const ProjectedTransferDescriptor& descriptor, - ArrayRef order) { - ProjectedTransferDescriptor reordered = descriptor; - size_t payloadFragmentCount = static_cast(descriptor.layout.payloadFragmentCount); - reordered.fragmentOffsets.clear(); - reordered.fragmentOffsets.reserve(descriptor.fragmentOffsets.size()); - for (size_t messageIndex : order) { - size_t offset = messageIndex * payloadFragmentCount; - for (size_t fragmentIndex = 0; fragmentIndex < payloadFragmentCount; ++fragmentIndex) - reordered.fragmentOffsets.push_back(descriptor.fragmentOffsets[offset + fragmentIndex]); - } - reordered.fragmentOffsetsByDim.clear(); - return reordered; -} - - -Operation* getPayloadDefiningOpInClassBlock(Value payload, MaterializedClass& materializedClass) { - Operation* definingOp = payload.getDefiningOp(); - if (!definingOp || 
definingOp->getBlock() != materializedClass.body) - return nullptr; - return definingOp; -} - -Operation* findScalarCommunicationInsertionPoint(MaterializedClass& materializedClass, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - Operation* terminator = materializedClass.body->getTerminator(); - bool afterLowerBound = lowerBound == nullptr; - - for (Operation& op : *materializedClass.body) { - if (&op == terminator) - break; - - if (!afterLowerBound) { - if (&op == lowerBound) - afterLowerBound = true; - continue; - } - - if (&op == lowerBound) - continue; - - auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); - if (existingMinChannel && existingMinChannel.getInt() > minChannelId) - return &op; - } - - return terminator; -} - -void setInsertionPointForScalarCommunication(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - state.rewriter.setInsertionPoint( - findScalarCommunicationInsertionPoint(materializedClass, minChannelId, lowerBound)); -} - -constexpr const char kRaptorCommOrderAttr[] = "raptor.comm_order"; - -int64_t computeBlockingCommunicationOrderKey(int32_t sourceCoreId, int32_t targetCoreId, int64_t channelId) { - int64_t lowCore = std::min(sourceCoreId, targetCoreId); - int64_t highCore = std::max(sourceCoreId, targetCoreId); - int64_t directionPhase = sourceCoreId <= targetCoreId ? 
0 : 1; - return (((lowCore * 1000000LL + highCore) * 2LL + directionPhase) * 1000000000LL) + channelId; -} - -int64_t getMinimumBlockingCommunicationOrderKey(const MessageVector& messages) { - assert(!messages.empty() && "expected at least one message"); - int64_t best = computeBlockingCommunicationOrderKey( - messages.sourceCoreIds.front(), messages.targetCoreIds.front(), messages.channelIds.front()); - for (size_t index = 1, end = messages.size(); index < end; ++index) { - best = std::min(best, computeBlockingCommunicationOrderKey( - messages.sourceCoreIds[index], messages.targetCoreIds[index], messages.channelIds[index])); - } - return best; -} - -Operation* findScalarCommunicationInsertionPointByOrder(MaterializedClass& materializedClass, - int64_t orderKey, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - Operation* terminator = materializedClass.body->getTerminator(); - bool afterLowerBound = lowerBound == nullptr; - - for (Operation& op : *materializedClass.body) { - if (&op == terminator) - break; - - if (!afterLowerBound) { - if (&op == lowerBound) - afterLowerBound = true; - continue; - } - - if (&op == lowerBound) - continue; - - if (auto existingOrder = op.getAttrOfType(kRaptorCommOrderAttr)) { - if (existingOrder.getInt() > orderKey) - return &op; - continue; - } - - auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); - if (existingMinChannel && existingMinChannel.getInt() > minChannelId) - return &op; - } - - return terminator; -} - -void setInsertionPointForScalarCommunicationOrder(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t orderKey, - int64_t minChannelId, - Operation* lowerBound = nullptr) { - if (!pimMaterializeScalarFanoutGlobalOrder) { - setInsertionPointForScalarCommunication(state, materializedClass, minChannelId, lowerBound); - return; - } - - state.rewriter.setInsertionPoint( - findScalarCommunicationInsertionPointByOrder(materializedClass, orderKey, minChannelId, lowerBound)); 
-} - -void markScalarCommunication(Operation* op, int64_t minChannelId, StringRef materializer = StringRef()) { - if (!op) - return; - op->setAttr(kRaptorMinChannelIdAttr, - IntegerAttr::get(IndexType::get(op->getContext()), minChannelId)); - if (!materializer.empty()) - op->setAttr(kRaptorMaterializerAttr, StringAttr::get(op->getContext(), materializer)); -} - -void markScalarCommunicationOrder(Operation* op, int64_t orderKey) { - if (!op) - return; - op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(IndexType::get(op->getContext()), orderKey)); -} - -std::optional getOperationOrdinalInBlock(Operation* op) { - if (!op || !op->getBlock()) - return std::nullopt; - - int64_t ordinal = 0; - for (Operation& candidate : *op->getBlock()) { - if (&candidate == op) - return ordinal; - ++ordinal; - } - return std::nullopt; -} - -std::string formatOperationForTrace(Operation* op) { - if (!op) - return ""; - - std::string text; - llvm::raw_string_ostream os(text); - os << op->getName().getStringRef(); - if (auto ordinal = getOperationOrdinalInBlock(op)) - os << "@" << *ordinal; - return os.str(); -} - -std::string formatValueForTrace(Value value, Block* localBody) { - if (!value) - return ""; - - std::string text; - llvm::raw_string_ostream os(text); - if (auto arg = dyn_cast(value)) { - os << "block_arg#" << arg.getArgNumber(); - return os.str(); - } - - Operation* definingOp = value.getDefiningOp(); - if (!definingOp) { - os << "external"; - return os.str(); - } - - os << definingOp->getName().getStringRef(); - if (definingOp->getBlock() == localBody) { - if (auto ordinal = getOperationOrdinalInBlock(definingOp)) - os << "@" << *ordinal; - } - else { - os << "@external-block"; - } - return os.str(); -} - -std::string formatClassForTrace(const MaterializedClass& materializedClass) { - std::string text; - llvm::raw_string_ostream os(text); - os << (materializedClass.isBatch ? 
"batch" : "scalar") << " class " << materializedClass.id << " cpus=["; - for (auto [index, cpu] : llvm::enumerate(materializedClass.cpus)) { - if (index) - os << ","; - os << cpu; - } - os << "]"; - return os.str(); -} - -std::string formatMessagesForTrace(const MessageVector& messages, unsigned maxMessages = 8) { - std::string text; - llvm::raw_string_ostream os(text); - os << "count=" << messages.size() << " ["; - unsigned limit = std::min(maxMessages, messages.size()); - for (unsigned index = 0; index < limit; ++index) { - if (index) - os << "; "; - os << "c" << messages.channelIds[index] << ":" << messages.sourceCoreIds[index] - << "->" << messages.targetCoreIds[index]; - } - if (messages.size() > limit) - os << "; ..."; - os << "]"; - return os.str(); -} - -void annotateCommunicationMaterialization(MaterializerState& state, - MaterializedClass& materializedClass, - Operation* op, - StringRef kind, - StringRef materializer, - StringRef phase, - std::optional minChannelId, - std::optional orderKey, - Value payload = Value(), - const MessageVector* messages = nullptr) { - if (!op) - return; - - MLIRContext* context = op->getContext(); - int64_t traceId = state.nextCommunicationTraceId++; - auto indexType = IndexType::get(context); - op->setAttr(kRaptorCommTraceIdAttr, IntegerAttr::get(indexType, traceId)); - op->setAttr(kRaptorCommTraceKindAttr, StringAttr::get(context, kind)); - op->setAttr(kRaptorCommTracePhaseAttr, StringAttr::get(context, phase)); - op->setAttr(kRaptorCommTraceClassIdAttr, IntegerAttr::get(indexType, materializedClass.id)); - op->setAttr(kRaptorCommTraceClassKindAttr, - StringAttr::get(context, materializedClass.isBatch ? 
"batch" : "scalar")); - if (!materializer.empty()) - op->setAttr(kRaptorMaterializerAttr, StringAttr::get(context, materializer)); - if (minChannelId) - op->setAttr(kRaptorMinChannelIdAttr, IntegerAttr::get(indexType, *minChannelId)); - if (orderKey) - op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(indexType, *orderKey)); - if (auto ordinal = getOperationOrdinalInBlock(op)) - op->setAttr(kRaptorCommTraceBlockOrdinalAttr, IntegerAttr::get(indexType, *ordinal)); - op->setAttr(kRaptorCommTracePayloadAttr, - StringAttr::get(context, formatValueForTrace(payload, materializedClass.body))); - if (messages) - op->setAttr(kRaptorCommTraceMessagesAttr, StringAttr::get(context, formatMessagesForTrace(*messages))); - - Operation* prev = op->getPrevNode(); - Operation* next = op->getNextNode(); - op->setAttr(kRaptorCommTracePrevOpAttr, StringAttr::get(context, formatOperationForTrace(prev))); - op->setAttr(kRaptorCommTraceNextOpAttr, StringAttr::get(context, formatOperationForTrace(next))); - - if (!pimTraceCommunicationMaterialization) - return; - - llvm::errs() << "[raptor:comm-materializer] #" << traceId << " " << kind - << " via " << materializer << " phase=" << phase << " " - << formatClassForTrace(materializedClass); - if (minChannelId) - llvm::errs() << " min_channel=" << *minChannelId; - if (orderKey) - llvm::errs() << " order=" << *orderKey; - if (auto ordinal = getOperationOrdinalInBlock(op)) - llvm::errs() << " block_ordinal=" << *ordinal; - llvm::errs() << " payload=" << formatValueForTrace(payload, materializedClass.body); - if (messages) - llvm::errs() << " messages=" << formatMessagesForTrace(*messages); - llvm::errs() << " prev=" << formatOperationForTrace(prev) - << " next=" << formatOperationForTrace(next) << "\n"; -} - -void setInsertionPointForEarlyCommunication(MaterializerState& state, MaterializedClass& materializedClass) { - auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); - if (lateIt != state.firstLateCommunicationOps.end() 
&& lateIt->second && lateIt->second->getBlock()) { - state.rewriter.setInsertionPoint(lateIt->second); - return; - } - - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); -} - -void setInsertionPointForLateCommunication(MaterializerState& state, MaterializedClass& materializedClass) { - state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); -} - - -Operation* findLateScalarCommunicationInsertionPoint(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t minChannelId) { - Operation* terminator = materializedClass.body->getTerminator(); - auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); - Operation* firstLate = lateIt == state.firstLateCommunicationOps.end() ? nullptr : lateIt->second; - if (!firstLate || firstLate->getBlock() != materializedClass.body) - return terminator; - - bool inLateRegion = false; - for (Operation& op : *materializedClass.body) { - if (&op == terminator) - break; - - if (!inLateRegion) { - if (&op == firstLate) - inLateRegion = true; - else - continue; - } - - auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); - if (existingMinChannel && existingMinChannel.getInt() > minChannelId) - return &op; - } - - return terminator; -} - -void setInsertionPointForLateScalarCommunication(MaterializerState& state, - MaterializedClass& materializedClass, - int64_t minChannelId) { - state.rewriter.setInsertionPoint( - findLateScalarCommunicationInsertionPoint(state, materializedClass, minChannelId)); -} - -void rememberLateCommunicationOp(MaterializerState& state, MaterializedClass& materializedClass, Operation* op) { - if (!op || op->getBlock() != materializedClass.body) - return; - - Operation*& firstLate = state.firstLateCommunicationOps[materializedClass.id]; - if (!firstLate || firstLate->getBlock() != materializedClass.body || op->isBeforeInBlock(firstLate)) - firstLate = op; -} - - - -constexpr const char kMinCommunicationChannelIdAttr[] = 
"raptor.min_channel_id"; - -std::optional getConstantIndexValue(Value value) { - APInt constant; - if (matchPattern(value, m_ConstantInt(&constant))) - return constant.getSExtValue(); - return std::nullopt; -} - -std::optional getCommunicationChannelId(Operation& op) { - if (auto attr = op.getAttrOfType(kMinCommunicationChannelIdAttr)) - return attr.getInt(); - - if (auto send = dyn_cast(&op)) - return getConstantIndexValue(send.getChannelId()); - if (auto receive = dyn_cast(&op)) - return getConstantIndexValue(receive.getChannelId()); - - return std::nullopt; -} - -int64_t getMinimumCommunicationChannelId(const MessageVector& messages) { - assert(!messages.empty() && "expected at least one message"); - return *std::min_element(messages.channelIds.begin(), messages.channelIds.end()); -} - -void markCommunicationChannelId(Operation* op, int64_t channelId) { - if (!op) - return; - op->setAttr(kMinCommunicationChannelIdAttr, - IntegerAttr::get(IntegerType::get(op->getContext(), 64), channelId)); -} - -Operation* getSameBlockDefiningOp(Value value, Block* block) { - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || definingOp->getBlock() != block) - return nullptr; - return definingOp; -} - - -bool valueDependsOnChannelReceive(Value root) { - SmallVector worklist; - DenseSet visitedValues; - DenseSet visitedOps; - worklist.push_back(root); - - auto visitOperand = [&](Value value) { - if (value && visitedValues.insert(value).second) - worklist.push_back(value); - }; - - while (!worklist.empty()) { - Value value = worklist.pop_back_val(); - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || !visitedOps.insert(definingOp).second) - continue; - - if (isa(definingOp)) - return true; - - for (Value operand : definingOp->getOperands()) - visitOperand(operand); - - for (Region& region : definingOp->getRegions()) { - for (Block& block : region) { - for (Operation& nested : block) { - for (Value operand : nested.getOperands()) - 
visitOperand(operand); - } - } - } - } - - return false; -} - -bool shouldDelayScalarSendUntilAfterReceives(Value payload, int32_t sourceCoreId, int32_t targetCoreId) { - if (sourceCoreId <= targetCoreId) - return false; - return valueDependsOnChannelReceive(payload); -} - -void partitionScalarMessagesByReceiveDependency(Value payload, - const MessageVector& messages, - MessageVector& earlyMessages, - MessageVector& lateMessages) { - for (size_t i = 0, e = messages.size(); i < e; ++i) { - MessageVector& bucket = shouldDelayScalarSendUntilAfterReceives( - payload, messages.sourceCoreIds[i], messages.targetCoreIds[i]) - ? lateMessages - : earlyMessages; - bucket.append(messages.channelIds[i], messages.sourceCoreIds[i], messages.targetCoreIds[i]); - } -} - -void setInsertionPointForScalarSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - int64_t minChannelId, - bool late) { - if (late) { - setInsertionPointForLateScalarCommunication(state, sourceClass, minChannelId); - return; - } - - setInsertionPointForScalarCommunication( - state, sourceClass, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); -} - - void appendScalarSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -3971,43 +3434,24 @@ void appendScalarSend(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); - bool late = shouldDelayScalarSendUntilAfterReceives(payload, sourceCoreId, targetCoreId); - int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); - if (pimMaterializeScalarFanoutGlobalOrder) - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, orderKey, channelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - else - setInsertionPointForScalarSend(state, sourceClass, payload, channelId, late); + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); Value channelIdValue 
= getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); - auto send = SpatChannelSendOp::create( - state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); - markScalarCommunication(send.getOperation(), channelId, "appendScalarSend"); - markScalarCommunicationOrder(send.getOperation(), orderKey); - MessageVector traceMessages; - traceMessages.append(channelId, sourceCoreId, targetCoreId); - annotateCommunicationMaterialization(state, - sourceClass, - send.getOperation(), - "send", - "appendScalarSend", - late ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), - channelId, - orderKey, - payload, - &traceMessages); - if (late && !pimMaterializeScalarFanoutGlobalOrder) - rememberLateCommunicationOp(state, sourceClass, send.getOperation()); + SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); } -LogicalResult emitScalarSendLoopAtInsertionPoint(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - int64_t minChannelId, - int64_t orderKey, - Location loc) { +LogicalResult appendScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); + assert(messages.size() > 1 && "send loop is only useful for multiple sends"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value upperBound = 
getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); @@ -4029,51 +3473,9 @@ LogicalResult emitScalarSendLoopAtInsertionPoint(MaterializerState& state, }); if (failed(sendLoop)) return failure(); - markScalarCommunication(sendLoop->loop.getOperation(), minChannelId, "appendScalarSendLoop"); - markScalarCommunicationOrder(sendLoop->loop.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - sourceClass, - sendLoop->loop.getOperation(), - "send-loop", - "appendScalarSendLoop", - "loop", - minChannelId, - orderKey, - payload, - &messages); return success(); } -LogicalResult appendScalarSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); - assert(messages.size() > 1 && "send loop is only useful for multiple sends"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - MessageVector orderedMessages = reorderScalarSendMessagesByChannel(messages); - if (pimMaterializeScalarFanoutGlobalOrder) { - for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) - appendScalarSend(state, - sourceClass, - payload, - orderedMessages.channelIds[index], - orderedMessages.sourceCoreIds[index], - orderedMessages.targetCoreIds[index], - loc); - return success(); - } - - int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - return emitScalarSendLoopAtInsertionPoint(state, sourceClass, payload, orderedMessages, minChannelId, orderKey, loc); -} - - FailureOr buildProjectedPackedPayload(MaterializerState& state, MaterializedClass& targetClass, Value fullPayload, 
@@ -4085,6 +3487,18 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, if (descriptor.layout.payloadFragmentCount == 1) return targetClass.op->emitError("projected packed payload builder expects a packed payload"); + FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fullPayload, + targetClass.op, + "projected packed payload tried to reuse a tensor from another materialized class"); + if (failed(localizedPayload)) + return failure(); + FailureOr localizedMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); + if (failed(localizedMessageIndex)) + return failure(); + Value init = tensor::EmptyOp::create( state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) .getResult(); @@ -4104,10 +3518,7 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, Value acc = iterArgs.front(); Value payloadFragmentCount = getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); - FailureOr localMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); - if (failed(localMessageIndex)) - return failure(); - Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localMessageIndex, payloadFragmentCount).getResult(); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localizedMessageIndex, payloadFragmentCount).getResult(); Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); FailureOr> fragmentOffsets = @@ -4115,12 +3526,12 @@ FailureOr buildProjectedPackedPayload(MaterializerState& state, if (failed(fragmentOffsets)) return failure(); FailureOr fragment = createStaticExtractSliceInClass( - state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + state, targetClass, loc, *localizedPayload, *fragmentOffsets, descriptor.layout.fragmentShape); if (failed(fragment)) return 
failure(); - FailureOr packedOffset = scaleIndexByDim0SizeInClass( - state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); + FailureOr packedOffset = + scaleIndexByDim0SizeInClass(state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); if (failed(packedOffset)) return failure(); FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); @@ -4143,16 +3554,25 @@ FailureOr buildProjectedPayloadForMessage(MaterializerState& state, if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) return failure(); + FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fullPayload, + targetClass.op, + "projected payload tried to reuse a tensor from another materialized class"); + if (failed(localizedPayload)) + return failure(); + if (descriptor.layout.payloadFragmentCount == 1) { FailureOr> fragmentOffsets = buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, messageIndex, loc); if (failed(fragmentOffsets)) return failure(); return createStaticExtractSliceInClass( - state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + state, targetClass, loc, *localizedPayload, *fragmentOffsets, descriptor.layout.fragmentShape); } - return buildProjectedPackedPayload(state, targetClass, fullPayload, descriptor, messageIndex, loc); + return buildProjectedPackedPayload(state, targetClass, *localizedPayload, descriptor, messageIndex, loc); } LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, @@ -4163,59 +3583,27 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, Location loc) { assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - SmallVector messageOrder = getScalarSendChannelOrder(messages); - MessageVector 
orderedMessages = reorderMessages(messages, messageOrder); - ProjectedTransferDescriptor orderedDescriptor = reorderProjectedDescriptorByMessageOrder(descriptor, messageOrder); - if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, orderedDescriptor))) - return failure(); - if (failed(verifyProjectedSendDescriptor(sourceClass.op, orderedDescriptor, orderedMessages))) + if (failed(verifyProjectedSendDescriptor(sourceClass.op, descriptor, messages))) return failure(); - int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - if (orderedMessages.size() == 1 || pimMaterializeScalarFanoutGlobalOrder) { - for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) { - int64_t channel = orderedMessages.channelIds[index]; - int32_t sourceCore = orderedMessages.sourceCoreIds[index]; - int32_t targetCore = orderedMessages.targetCoreIds[index]; - int64_t localOrderKey = computeBlockingCommunicationOrderKey(sourceCore, targetCore, channel); - setInsertionPointForScalarCommunicationOrder( - state, sourceClass, localOrderKey, channel, getPayloadDefiningOpInClassBlock(payload, sourceClass)); - - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channel); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCore); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCore); - Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(index)); - FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, messageIndex, loc); - if (failed(sendPayload)) - 
return failure(); - auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); - markScalarCommunication(send.getOperation(), channel, "appendProjectedScalarSendLoop.single"); - markScalarCommunicationOrder(send.getOperation(), localOrderKey); - MessageVector traceMessages; - traceMessages.append(channel, sourceCore, targetCore); - annotateCommunicationMaterialization(state, - sourceClass, - send.getOperation(), - "send", - "appendProjectedScalarSendLoop.single", - "projected-single", - channel, - localOrderKey, - *sendPayload, - &traceMessages); - } + if (messages.size() == 1) { + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); + Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass, payload, descriptor, messageIndex, loc); + if (failed(sendPayload)) + return failure(); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); return success(); } Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(orderedMessages.size())); + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); auto projectedSendLoop = buildNormalizedScfFor( @@ -4226,11 +3614,11 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, step, ValueRange {}, [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - Value 
channelId = createIndexedChannelId(state, sourceClass.op, orderedMessages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, orderedMessages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, orderedMessages, index, loc); + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); FailureOr sendPayload = - buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, index, loc); + buildProjectedPayloadForMessage(state, sourceClass, payload, descriptor, index, loc); if (failed(sendPayload)) return failure(); SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); @@ -4238,22 +3626,9 @@ LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, }); if (failed(projectedSendLoop)) return failure(); - markScalarCommunication(projectedSendLoop->loop.getOperation(), minChannelId, "appendProjectedScalarSendLoop.loop"); - markScalarCommunicationOrder(projectedSendLoop->loop.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - sourceClass, - projectedSendLoop->loop.getOperation(), - "send-loop", - "appendProjectedScalarSendLoop.loop", - "projected-loop", - minChannelId, - orderKey, - payload, - &orderedMessages); return success(); } - LogicalResult appendSend(MaterializerState& state, MaterializedClass& sourceClass, Value payload, @@ -4268,21 +3643,7 @@ LogicalResult appendSend(MaterializerState& state, Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); - auto 
send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - int64_t minChannelId = getMinimumChannelId(messages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); - markScalarCommunication(send.getOperation(), minChannelId, "appendSend.batch"); - markScalarCommunicationOrder(send.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - sourceClass, - send.getOperation(), - "send", - "appendSend.batch", - "batch-lane-indexed", - minChannelId, - orderKey, - payload, - &messages); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); return success(); } @@ -4306,74 +3667,29 @@ Value appendScalarReceive(MaterializerState& state, int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId, - Location loc, - bool lateReceive = false) { + Location loc) { assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); - int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); - if (lateReceive) - setInsertionPointForLateScalarCommunication(state, targetClass, channelId); - else - setInsertionPointForScalarCommunicationOrder(state, targetClass, orderKey, channelId); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); - auto receive = SpatChannelReceiveOp::create( - state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); - markScalarCommunication(receive.getOperation(), channelId, - lateReceive ? 
"appendScalarReceive.late" : "appendScalarReceive"); - markScalarCommunicationOrder(receive.getOperation(), orderKey); - MessageVector traceMessages; - traceMessages.append(channelId, sourceCoreId, targetCoreId); - annotateCommunicationMaterialization(state, - targetClass, - receive.getOperation(), - "receive", - lateReceive ? "appendScalarReceive.late" : "appendScalarReceive", - lateReceive ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), - channelId, - orderKey, - Value(), - &traceMessages); - return receive.getOutput(); + return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) + .getOutput(); } - Value appendReceive( - MaterializerState& state, - MaterializedClass& targetClass, - Type type, - const MessageVector& messages, - Location loc, - bool lateReceive = false) { + MaterializerState& state, MaterializedClass& targetClass, Type type, const MessageVector& messages, Location loc) { assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); assert(!messages.empty() && "expected at least one receive"); - if (lateReceive) - setInsertionPointForLateScalarCommunication(state, targetClass, getMinimumChannelId(messages.channelIds)); - else - setInsertionPointForEarlyCommunication(state, targetClass); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); if (targetClass.isBatch) { Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); - auto receive = SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId); - int64_t minChannelId = getMinimumChannelId(messages.channelIds); - int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); - 
markScalarCommunication(receive.getOperation(), minChannelId, "appendReceive.batch"); - markScalarCommunicationOrder(receive.getOperation(), orderKey); - annotateCommunicationMaterialization(state, - targetClass, - receive.getOperation(), - "receive", - "appendReceive.batch", - lateReceive ? "late-batch" : "early-batch", - minChannelId, - orderKey, - Value(), - &messages); - return receive.getOutput(); + return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput(); } assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); @@ -4383,141 +3699,7 @@ Value appendReceive( messages.channelIds.front(), messages.sourceCoreIds.front(), messages.targetCoreIds.front(), - loc, - lateReceive); -} - -Value appendScalarReceiveAtCurrentInsertionPoint(MaterializerState& state, - MaterializedClass& targetClass, - Type type, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { - assert(!targetClass.isBatch && "demand scalar receive expects a scalar target class"); - - int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); - Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); - Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); - Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); - auto receive = SpatChannelReceiveOp::create( - state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); - markScalarCommunication(receive.getOperation(), channelId, "appendScalarReceive.demand"); - markScalarCommunicationOrder(receive.getOperation(), orderKey); - MessageVector traceMessages; - traceMessages.append(channelId, sourceCoreId, targetCoreId); - annotateCommunicationMaterialization(state, - targetClass, - receive.getOperation(), - "receive", - 
"appendScalarReceive.demand", - "demand", - channelId, - orderKey, - Value(), - &traceMessages); - return receive.getOutput(); -} - -std::optional lookupPendingScalarReceiveIndex(MaterializerState& state, - ProducerKey key, - ClassId targetClassId) { - auto keyIt = state.pendingScalarReceiveLookup.find(key); - if (keyIt == state.pendingScalarReceiveLookup.end()) - return std::nullopt; - - auto classIt = keyIt->second.find(targetClassId); - if (classIt == keyIt->second.end()) - return std::nullopt; - return classIt->second; -} - -void recordPendingScalarReceive(MaterializerState& state, - ClassId targetClassId, - ArrayRef keys, - Type receiveType, - const MessageVector& messages, - Location loc) { - if (keys.empty()) - return; - - if (lookupPendingScalarReceiveIndex(state, keys.front(), targetClassId)) - return; - - size_t recordIndex = state.pendingScalarReceives.size(); - state.pendingScalarReceives.emplace_back(keys, targetClassId, receiveType, messages, loc); - - for (ProducerKey key : keys) - state.pendingScalarReceiveLookup[key][targetClassId] = recordIndex; -} - -FailureOr materializePendingScalarReceive(MaterializerState& state, - MaterializedClass& targetClass, - size_t recordIndex, - Location loc) { - if (recordIndex >= state.pendingScalarReceives.size()) - return targetClass.op->emitError("pending scalar receive index is out of bounds"); - - PendingScalarReceiveRecord& record = state.pendingScalarReceives[recordIndex]; - if (record.targetClassId != targetClass.id) - return targetClass.op->emitError("pending scalar receive target class mismatch"); - - if (record.materialized) - return record.value; - - if (targetClass.isBatch) - return targetClass.op->emitError("pending scalar receive cannot materialize into a batch class"); - if (record.messages.size() != 1) - return targetClass.op->emitError("pending scalar receive expected exactly one scalar message"); - - Location receiveLoc = loc; - Value received = appendScalarReceiveAtCurrentInsertionPoint(state, - 
targetClass, - record.receiveType, - record.messages.channelIds.front(), - record.messages.sourceCoreIds.front(), - record.messages.targetCoreIds.front(), - receiveLoc); - record.materialized = true; - record.value = received; - - for (ProducerKey key : record.keys) - state.availableValues.record(key, targetClass.id, received); - - return received; -} - - -LogicalResult materializePendingScalarReceivesForWholeBatchInput(MaterializerState& state, - MaterializedClass& targetClass, - ProducerKey wholeBatchKey, - Location loc) { - if (targetClass.isBatch || !isWholeBatchProducerKey(wholeBatchKey)) - return success(); - - SmallVector pendingIndices; - for (auto [recordIndex, record] : llvm::enumerate(state.pendingScalarReceives)) { - if (record.targetClassId != targetClass.id || record.materialized) - continue; - - bool contributesToWholeBatch = llvm::any_of(record.keys, [&](ProducerKey fragmentKey) { - return containsProducerKey(wholeBatchKey, fragmentKey); - }); - if (contributesToWholeBatch) - pendingIndices.push_back(recordIndex); - } - - if (pendingIndices.empty()) - return success(); - - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - for (size_t recordIndex : pendingIndices) { - FailureOr received = materializePendingScalarReceive(state, targetClass, recordIndex, loc); - if (failed(received)) - return failure(); - } - - return success(); + loc); } LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, @@ -4576,9 +3758,12 @@ LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, packedRun.messages = std::move(messages); - PackedScalarRunSlot slot; - llvm::append_range(slot.keys, keys); - packedRun.slots.push_back(std::move(slot)); + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) return failure(); @@ -4593,7 
+3778,6 @@ struct ScalarSourceReceivePlan { Type receiveType; Operation* projectedExtractOp = nullptr; ProjectedFragmentLayout projectedLayout; - std::optional projectedDescriptor; }; struct ProjectedScalarSendGroup { @@ -4715,19 +3899,12 @@ FailureOr buildScalarSourceFanoutPlan(MaterializerState& if (*descriptor) { const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; - if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) { - if (!fanoutPlan.ordinaryMessages) - fanoutPlan.ordinaryMessages = MessageVector {}; - fanoutPlan.ordinaryMessages->append( - receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); - fanoutPlan.receivePlans.push_back(std::move(receivePlan)); - continue; - } + if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) + return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type"); receivePlan.receiveType = projectedDescriptor.payloadType; receivePlan.projectedExtractOp = projectedDescriptor.extractOp; receivePlan.projectedLayout = projectedDescriptor.layout; - receivePlan.projectedDescriptor = projectedDescriptor; auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); @@ -4779,145 +3956,6 @@ LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, return success(); } - -struct GloballyOrderedScalarFanoutEvent { - size_t receivePlanIndex = 0; - int64_t minChannelId = 0; - int64_t orderKey = 0; - int32_t minSourceCoreId = 0; - int32_t minTargetCoreId = 0; -}; - -GloballyOrderedScalarFanoutEvent makeGloballyOrderedScalarFanoutEvent(size_t receivePlanIndex, - const ScalarSourceReceivePlan& plan) { - assert(!plan.messages.empty() && "expected a communication event with at least one message"); - GloballyOrderedScalarFanoutEvent event; - 
event.receivePlanIndex = receivePlanIndex; - event.minChannelId = plan.messages.channelIds.front(); - event.orderKey = getMinimumBlockingCommunicationOrderKey(plan.messages); - event.minSourceCoreId = plan.messages.sourceCoreIds.front(); - event.minTargetCoreId = plan.messages.targetCoreIds.front(); - - for (size_t index = 1, end = plan.messages.size(); index < end; ++index) { - event.minChannelId = std::min(event.minChannelId, plan.messages.channelIds[index]); - event.minSourceCoreId = std::min(event.minSourceCoreId, plan.messages.sourceCoreIds[index]); - event.minTargetCoreId = std::min(event.minTargetCoreId, plan.messages.targetCoreIds[index]); - } - - return event; -} - -SmallVector -collectGloballyOrderedScalarFanoutEvents(const ScalarSourceFanoutPlan& plan) { - SmallVector events; - events.reserve(plan.receivePlans.size()); - - for (auto [index, receivePlan] : llvm::enumerate(plan.receivePlans)) - if (!receivePlan.messages.empty()) - events.push_back(makeGloballyOrderedScalarFanoutEvent(index, receivePlan)); - - llvm::sort(events, [](const GloballyOrderedScalarFanoutEvent& lhs, - const GloballyOrderedScalarFanoutEvent& rhs) { - if (lhs.orderKey != rhs.orderKey) - return lhs.orderKey < rhs.orderKey; - if (lhs.minChannelId != rhs.minChannelId) - return lhs.minChannelId < rhs.minChannelId; - if (lhs.minSourceCoreId != rhs.minSourceCoreId) - return lhs.minSourceCoreId < rhs.minSourceCoreId; - if (lhs.minTargetCoreId != rhs.minTargetCoreId) - return lhs.minTargetCoreId < rhs.minTargetCoreId; - return lhs.receivePlanIndex < rhs.receivePlanIndex; - }); - - return events; -} - -LogicalResult emitGloballyOrderedScalarFanoutSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - const ScalarSourceReceivePlan& plan, - Location loc) { - if (plan.projectedDescriptor) - return appendProjectedScalarSendLoop(state, sourceClass, payload, *plan.projectedDescriptor, plan.messages, loc); - - return appendSend(state, sourceClass, payload, 
plan.messages, loc); -} - -bool isMaterializedBlockingCommunication(Operation& op) { - return isa(&op) || op.hasAttr(kRaptorMinChannelIdAttr) - || op.hasAttr(kRaptorCommOrderAttr); -} - -bool payloadIsAvailableOnlyAfterPriorCommunication(Value payload, MaterializedClass& sourceClass) { - Operation* lowerBound = getPayloadDefiningOpInClassBlock(payload, sourceClass); - if (!lowerBound) - return false; - - bool sawPriorCommunication = false; - Operation* terminator = sourceClass.body->getTerminator(); - for (Operation& op : *sourceClass.body) { - if (&op == terminator) - break; - - if (&op == lowerBound) - return sawPriorCommunication || isMaterializedBlockingCommunication(op); - - if (isMaterializedBlockingCommunication(op)) - sawPriorCommunication = true; - } - - return sawPriorCommunication; -} - -bool shouldPlaceMatchingScalarFanoutReceiveLate(MaterializedClass& sourceClass, - Value payload, - const MessageVector& messages) { - if (payloadIsAvailableOnlyAfterPriorCommunication(payload, sourceClass)) - return true; - - for (size_t index = 0, end = messages.size(); index < end; ++index) - if (shouldDelayScalarSendUntilAfterReceives( - payload, messages.sourceCoreIds[index], messages.targetCoreIds[index])) - return true; - return false; -} - -LogicalResult emitGloballyOrderedScalarSourceFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - const ScalarSourceFanoutPlan& plan, - Location loc) { - SmallVector events = collectGloballyOrderedScalarFanoutEvents(plan); - - for (const GloballyOrderedScalarFanoutEvent& event : events) { - const ScalarSourceReceivePlan& planEntry = plan.receivePlans[event.receivePlanIndex]; - MaterializedClass& targetClass = state.classes[planEntry.targetClass]; - - if (failed(emitGloballyOrderedScalarFanoutSend(state, sourceClass, payload, planEntry, loc))) - return failure(); - - if (!targetClass.isBatch && !planEntry.projectedExtractOp) { - recordPendingScalarReceive(state, targetClass.id, 
keys, planEntry.receiveType, planEntry.messages, loc); - continue; - } - - bool lateReceive = shouldPlaceMatchingScalarFanoutReceiveLate(sourceClass, payload, planEntry.messages); - Value received = appendReceive(state, targetClass, planEntry.receiveType, planEntry.messages, loc, lateReceive); - - if (planEntry.projectedExtractOp) { - state.projectedExtractReplacements[planEntry.projectedExtractOp][planEntry.targetClass] = - ProjectedExtractReplacement {received, planEntry.projectedLayout}; - continue; - } - - for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, received); - } - - return success(); -} - LogicalResult emitScalarSourceCommunication( MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); @@ -4929,9 +3967,6 @@ LogicalResult emitScalarSourceCommunication( auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); if (failed(fanoutPlan)) return failure(); - if (pimMaterializeScalarFanoutGlobalOrder) - return emitGloballyOrderedScalarSourceFanout(state, sourceClass, keys, payload, *fanoutPlan, loc); - if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) return failure(); @@ -4953,112 +3988,6 @@ LogicalResult emitScalarSourceCommunication( return success(); } -FailureOr emitOrderedBatchToBatchCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - Value payload, - const MessageVector& messages, - Location loc) { - assert(sourceClass.isBatch && targetClass.isBatch && "ordered batch communication expects two batch classes"); - if (failed(messages.verify(sourceClass.op))) - return failure(); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape()) - return sourceClass.op->emitError("ordered batch communication 
expects a static ranked tensor payload"); - - auto makeEmpty = [&](MaterializedClass& materializedClass) -> Value { - return tensor::EmptyOp::create( - state.rewriter, loc, payloadType.getShape(), payloadType.getElementType()) - .getResult(); - }; - - setInsertionPointForEarlyCommunication(state, sourceClass); - Value sendChannelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); - Value sendSourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); - Value sendTargetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); - Value sendEarlyCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sle, - sendSourceCoreId, - sendTargetCoreId) - .getResult(); - auto earlySendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendEarlyCond, /*withElseRegion=*/false); - state.rewriter.setInsertionPoint(earlySendIf.thenBlock()->getTerminator()); - auto earlySend = SpatChannelSendOp::create( - state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); - markScalarCommunication( - earlySend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlySend"); - - setInsertionPointForLateCommunication(state, sourceClass); - Value sendLateCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sgt, - sendSourceCoreId, - sendTargetCoreId) - .getResult(); - auto lateSendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendLateCond, /*withElseRegion=*/false); - rememberLateCommunicationOp(state, sourceClass, lateSendIf.getOperation()); - state.rewriter.setInsertionPoint(lateSendIf.thenBlock()->getTerminator()); - auto lateSend = SpatChannelSendOp::create( - state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); - markScalarCommunication( - lateSend.getOperation(), getMinimumChannelId(messages.channelIds), 
"emitOrderedBatchToBatchCommunication.lateSend"); - - setInsertionPointForEarlyCommunication(state, targetClass); - Value recvChannelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); - Value recvSourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); - Value recvTargetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); - Value recvEarlyCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sle, - recvSourceCoreId, - recvTargetCoreId) - .getResult(); - auto earlyReceiveIf = scf::IfOp::create( - state.rewriter, loc, TypeRange {payload.getType()}, recvEarlyCond, /*withElseRegion=*/true); - Operation* earlyThenYield = earlyReceiveIf.thenBlock()->getTerminator(); - state.rewriter.setInsertionPoint(earlyThenYield); - auto earlyReceive = SpatChannelReceiveOp::create( - state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); - markScalarCommunication( - earlyReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlyReceive"); - Value earlyReceived = earlyReceive.getOutput(); - state.rewriter.modifyOpInPlace(earlyThenYield, [&] { earlyThenYield->setOperands(ValueRange {earlyReceived}); }); - Operation* earlyElseYield = earlyReceiveIf.elseBlock()->getTerminator(); - state.rewriter.setInsertionPoint(earlyElseYield); - Value empty = makeEmpty(targetClass); - state.rewriter.modifyOpInPlace(earlyElseYield, [&] { earlyElseYield->setOperands(ValueRange {empty}); }); - - setInsertionPointForLateCommunication(state, targetClass); - Value recvLateCond = arith::CmpIOp::create( - state.rewriter, - loc, - arith::CmpIPredicate::sgt, - recvSourceCoreId, - recvTargetCoreId) - .getResult(); - auto lateReceiveIf = scf::IfOp::create( - state.rewriter, loc, TypeRange {payload.getType()}, recvLateCond, /*withElseRegion=*/true); - rememberLateCommunicationOp(state, targetClass, 
lateReceiveIf.getOperation()); - Operation* lateThenYield = lateReceiveIf.thenBlock()->getTerminator(); - state.rewriter.setInsertionPoint(lateThenYield); - auto lateReceive = SpatChannelReceiveOp::create( - state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); - markScalarCommunication( - lateReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateReceive"); - Value lateReceived = lateReceive.getOutput(); - state.rewriter.modifyOpInPlace(lateThenYield, [&] { lateThenYield->setOperands(ValueRange {lateReceived}); }); - Operation* lateElseYield = lateReceiveIf.elseBlock()->getTerminator(); - state.rewriter.modifyOpInPlace( - lateElseYield, [&] { lateElseYield->setOperands(ValueRange {earlyReceiveIf.getResult(0)}); }); - - return lateReceiveIf.getResult(0); -} - LogicalResult emitClassToClassCommunication(MaterializerState& state, MaterializedClass& sourceClass, MaterializedClass& targetClass, @@ -5076,13 +4005,16 @@ LogicalResult emitClassToClassCommunication(MaterializerState& state, if (!targetClass.isBatch) { MessageVector messages; - messages.channelIds.reserve(sourceClass.cpus.size()); - messages.sourceCoreIds.reserve(sourceClass.cpus.size()); - messages.targetCoreIds.reserve(sourceClass.cpus.size()); + messages.channelIds.reserve(keys.size()); + messages.sourceCoreIds.reserve(keys.size()); + messages.targetCoreIds.reserve(keys.size()); auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); if (failed(targetCpu)) return failure(); + if (keys.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError( + "batch-to-scalar communication expects one producer key per source lane"); for (CpuId sourceCpu : sourceClass.cpus) { auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); if (failed(checkedSourceCpu)) @@ -5121,13 +4053,12 @@ LogicalResult 
emitClassToClassCommunication(MaterializerState& state, messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); } - FailureOr received = - emitOrderedBatchToBatchCommunication(state, sourceClass, targetClass, payload, messages, loc); - if (failed(received)) + if (failed(appendSend(state, sourceClass, payload, messages, loc))) return failure(); + Value received = appendReceive(state, targetClass, payload.getType(), messages, loc); for (ProducerKey key : keys) - state.availableValues.record(key, targetClass.id, *received); + state.availableValues.record(key, targetClass.id, received); return success(); } @@ -5143,9 +4074,7 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val : StringRef("")); unsigned resultIndex = resultIt->second; - if (payload.getType() != originalOutput.getType()) - return sourceClass.op->emitError("cannot set host output from fragment payload without projection") - << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); if (!sourceClass.isBatch) { auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); @@ -5153,9 +4082,11 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val return sourceClass.op->emitError("expected spat.yield terminator in materialized compute"); if (resultIndex >= yieldOp.getNumOperands()) return sourceClass.op->emitError("host result index out of range for materialized compute"); + if (payload.getType() != originalOutput.getType()) + return sourceClass.op->emitError("cannot set scalar host output from fragment payload") + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); - state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); return success(); } @@ -5177,830 +4108,15 @@ 
setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); - state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); return success(); } -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); - -LogicalResult emitProjectedBatchHostOutput(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value originalOutput, - Value payload, - Location loc) { - if (!sourceClass.isBatch) - return sourceClass.op->emitError("projected batch host publication expects a batch owner class"); - auto batch = cast(sourceClass.op); - - auto ownerIt = sourceClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == sourceClass.hostOutputToResultIndex.end()) - return sourceClass.op->emitError("missing host result slot for projected batch output"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return sourceClass.op->emitError("projected batch host publication expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for publication"); - - auto sourceLaneArg = sourceBatch.getLaneArgument(); - if (!sourceLaneArg) - return sourceBatch.emitOpError("missing source compute_batch lane argument for host projection"); - - // The projection coordinates are part of the source batch publication. 
- // Build any affine/index helper ops in the source batch body, not at the - // caller's current insertion point. Otherwise a scalar host-owner body may - // accidentally capture the source scheduled_compute_batch lane argument. - OpBuilder::InsertionGuard projectionGuard(state.rewriter); - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - - FailureOr projectionLaneValue = createProjectionLaneValueForKeys(state, sourceClass, keys, loc); - if (failed(projectionLaneValue)) - return failure(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(projection->getMixedOffsets().size()); - sizes.reserve(projection->getMixedSizes().size()); - strides.reserve(projection->getMixedStrides().size()); - - for (OpFoldResult offset : projection->getMixedOffsets()) { - FailureOr remapped = - remapProjectionIndexLike(state, sourceClass.op, offset, *sourceLaneArg, *projectionLaneValue, loc); - if (failed(remapped)) - return sourceClass.op->emitError("failed to remap projected batch host offsets"); - offsets.push_back(*remapped); - } - for (OpFoldResult size : projection->getMixedSizes()) { - FailureOr remapped = - remapProjectionIndexLike(state, sourceClass.op, size, *sourceLaneArg, *projectionLaneValue, loc); - if (failed(remapped)) - return sourceClass.op->emitError("failed to remap projected batch host sizes"); - sizes.push_back(*remapped); - } - for (OpFoldResult stride : projection->getMixedStrides()) { - FailureOr remapped = - remapProjectionIndexLike(state, sourceClass.op, stride, *sourceLaneArg, *projectionLaneValue, loc); - if (failed(remapped)) - return sourceClass.op->emitError("failed to remap projected batch host strides"); - strides.push_back(*remapped); - } - - auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); - if (!inParallelOp) - return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); - - auto outputArg = batch.getOutputArgument(ownerIt->second); - 
if (!outputArg) - return batch.emitOpError("missing host output block argument for projected batch publication"); - - state.hostReplacements[originalOutput] = sourceClass.op->getResult(ownerIt->second); - state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); - tensor::ParallelInsertSliceOp::create(state.rewriter, loc, payload, *outputArg, offsets, sizes, strides); - return success(); -} - -FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); - -FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { - if (value == laneArg) - return static_cast(lane); - - if (std::optional constant = matchConstantIndexValue(value)) - return *constant; - - auto affineApply = value.getDefiningOp(); - if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) - return failure(); - - SmallVector operands; - operands.reserve(affineApply.getMapOperands().size()); - for (Value operand : affineApply.getMapOperands()) { - FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane); - if (failed(evaluated)) - return failure(); - operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); - } - - SmallVector results; - if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) - return failure(); - - auto intAttr = dyn_cast(results.front()); - if (!intAttr) - return failure(); - return intAttr.getInt(); -} - -FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) { - if (auto attr = llvm::dyn_cast(value)) { - auto intAttr = dyn_cast(attr); - if (!intAttr) - return failure(); - return intAttr.getInt(); - } - return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); -} - -FailureOr -getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { - auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); - if (!inParallel) - return failure(); - - auto 
firstOutputArg = batch.getOutputArgument(0); - if (!firstOutputArg) - return failure(); - - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert) - continue; - - auto outputArg = dyn_cast(insert.getDest()); - if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) - continue; - - unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); - if (candidateIndex == resultIndex) - return insert; - } - - return failure(); -} - -FailureOr> -evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) { - SmallVector evaluated; - evaluated.reserve(values.size()); - for (OpFoldResult value : values) { - FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane); - if (failed(index)) - return failure(); - evaluated.push_back(*index); - } - return evaluated; -} - - -bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, - const AffineProjectedInputSliceMatch& match, - ProducerKey producer, - uint32_t consumerLane) { - auto producerBatch = dyn_cast_or_null(producer.instance.op); - if (!producerBatch) - return true; - - FailureOr producerProjection = - getBatchResultProjectionInsert(producerBatch, producer.resultIndex); - if (failed(producerProjection)) - return true; - - std::optional producerLaneArg = producerBatch.getLaneArgument(); - std::optional consumerLaneArg = consumerBatch.getLaneArgument(); - if (!producerLaneArg || !consumerLaneArg) - return false; - - SmallVector consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end()); - SmallVector loopIterationIndices(match.loops.size(), 0); - - const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { - SmallVector consumerOffsets; - consumerOffsets.reserve(match.offsets.size()); - for (OpFoldResult offset : match.offsets) { - FailureOr evaluated = - evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices); - if 
(failed(evaluated)) - return false; - consumerOffsets.push_back(*evaluated); - } - - uint32_t producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount; - for (uint32_t producerLane = producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) { - FailureOr> producerOffsets = - evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane); - FailureOr> producerSizes = - evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane); - FailureOr> producerStrides = - evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane); - if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides)) - return false; - if (!areAllUnitStrides(*producerStrides)) - return false; - if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes)) - return true; - } - - return false; - }; - - if (match.loops.empty()) - return consumerSliceFitsOneProducerFragment(); - - const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { - if (loopIndex == match.loops.size()) - return consumerSliceFitsOneProducerFragment(); - - for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { - loopIterationIndices[loopIndex] = iteration; - if (!self(self, loopIndex + 1)) - return false; - } - return true; - }; - - return recurse(recurse, 0); -} - -LogicalResult insertProjectedBatchHostFragment(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - uint32_t lane, - Value payload) { - if (ownerClass.isBatch) - return ownerClass.op->emitError("projected batch host fallback expects a scalar owner class"); - - auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == ownerClass.hostOutputToResultIndex.end()) - return ownerClass.op->emitError("missing host result slot for projected batch host fragment"); - - 
auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return ownerClass.op->emitError("projected batch host fallback expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for materialization"); - - auto laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) - return sourceBatch.emitOpError("missing compute_batch lane argument for host projection"); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); - if (failed(offsets) || failed(sizes) || failed(strides)) - return ownerClass.op->emitError("failed to evaluate batch host projection coordinates"); - - auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); - if (!yieldOp) - return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); - - unsigned hostResultIndex = ownerIt->second; - if (hostResultIndex >= yieldOp.getNumOperands()) - return ownerClass.op->emitError("host result index out of range for projected batch host fragment"); - if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) - return ownerClass.op->emitError("projected batch host fragment expected a full host accumulator tensor") - << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() - << " outputType=" << originalOutput.getType(); - - state.rewriter.setInsertionPoint(yieldOp); - Value updated = tensor::InsertSliceOp::create(state.rewriter, - payload.getLoc(), - payload, - 
yieldOp.getOperand(hostResultIndex), - ValueRange {}, - ValueRange {}, - ValueRange {}, - *offsets, - *sizes, - *strides) - .getResult(); - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, updated); }); - state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); - return success(); -} - - -LogicalResult emitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - ArrayRef keys, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - if (ownerClass.isBatch) - return ownerClass.op->emitError("projected batch host receive loop expects a scalar owner class"); - if (keys.empty()) - return success(); - if (keys.size() != messages.size()) - return ownerClass.op->emitError("projected batch host receive loop message metadata is inconsistent"); - - auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == ownerClass.hostOutputToResultIndex.end()) - return ownerClass.op->emitError("missing host result slot for projected batch host receive loop"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return ownerClass.op->emitError("projected batch host receive loop expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for receive loop"); - - auto laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) - return sourceBatch.emitOpError("missing compute_batch lane argument for projected host receive loop"); - - auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); - if (!yieldOp) - return ownerClass.op->emitError("expected spat.yield terminator in scalar 
host owner"); - - unsigned hostResultIndex = ownerIt->second; - if (hostResultIndex >= yieldOp.getNumOperands()) - return ownerClass.op->emitError("host result index out of range for projected batch host receive loop"); - if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) - return ownerClass.op->emitError("projected batch host receive loop expected a full host accumulator tensor") - << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() - << " outputType=" << originalOutput.getType(); - - unsigned rank = projection->getMixedOffsets().size(); - SmallVector, 4> offsetsByDim(rank); - SmallVector, 4> sizesByDim(rank); - SmallVector, 4> stridesByDim(rank); - for (ProducerKey key : keys) { - if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() - || key.instance.laneCount != 1) - return ownerClass.op->emitError("projected batch host receive loop expects one-lane fragments from one output"); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); - if (failed(offsets) || failed(sizes) || failed(strides)) - return ownerClass.op->emitError("failed to evaluate projected batch host receive loop coordinates"); - if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) - return ownerClass.op->emitError("projected batch host receive loop coordinate rank mismatch"); - - for (unsigned dim = 0; dim < rank; ++dim) { - offsetsByDim[dim].push_back((*offsets)[dim]); - sizesByDim[dim].push_back((*sizes)[dim]); - stridesByDim[dim].push_back((*strides)[dim]); - } - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); - Value 
upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 1); - - state.rewriter.setInsertionPoint(yieldOp); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {yieldOp.getOperand(hostResultIndex)}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - Value channelId = createIndexedChannelId(state, ownerClass.op, messages, flatIndex, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, ownerClass.op, messages, flatIndex, loc); - Value targetCoreId = createIndexedTargetCoreId(state, ownerClass.op, messages, flatIndex, loc); - Value fragment = SpatChannelReceiveOp::create( - state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) - .getOutput(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned dim = 0; dim < rank; ++dim) { - offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); - sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); - strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); - } - - Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) - .getResult(); - yielded.push_back(updated); - return success(); - }); - if (failed(loop)) - return failure(); - markScalarCommunication( - loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitProjectedBatchHostReceiveInsertLoop"); - - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); - state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); - return 
success(); -} - -std::optional tryEmitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - ArrayRef keys, - Location loc) { - if (keys.empty()) - return success(); - - WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(keys.front(), ownerClass.id); - ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); - for (size_t runIndex : runIndices) { - PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); - if (run.kind != PackedScalarRunKind::DeferredReceive) - continue; - SmallVector runKeys = flattenPackedScalarRunKeys(run); - if (!llvm::equal(runKeys, keys)) - continue; - return emitProjectedBatchHostReceiveInsertLoop( - state, ownerClass, originalOutput, runKeys, run.fragmentType, run.messages, loc); - } - - return std::nullopt; -} - -FailureOr getLeadingPackedFragmentType(Operation* anchor, Value payload, size_t fragmentCount) { - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) - return failure(); - if (payloadType.getDimSize(0) != static_cast(fragmentCount)) - return failure(); - - SmallVector fragmentShape(payloadType.getShape().begin(), payloadType.getShape().end()); - fragmentShape[0] = 1; - return RankedTensorType::get(fragmentShape, payloadType.getElementType(), payloadType.getEncoding()); -} - -LogicalResult emitScalarPackedProjectedHostSendLoop(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(!sourceClass.isBatch && "packed projected host send loop expects a scalar source"); - assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) - return 
sourceClass.op->emitError("packed projected host send loop expects a static ranked payload"); - - setInsertionPointForScalarCommunication(state, sourceClass, getMinimumChannelId(messages.channelIds)); - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); - Value upperBound = - getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); - - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(payloadType.getRank()); - sizes.reserve(payloadType.getRank()); - strides.reserve(payloadType.getRank()); - offsets.push_back(index); - sizes.push_back(state.rewriter.getIndexAttr(1)); - strides.push_back(state.rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { - offsets.push_back(state.rewriter.getIndexAttr(0)); - sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); - strides.push_back(state.rewriter.getIndexAttr(1)); - } - - Value fragment = tensor::ExtractSliceOp::create( - state.rewriter, loc, fragmentType, payload, offsets, sizes, strides) - .getResult(); - Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); - Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); - Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, fragment); - return success(); - }); - if (failed(loop)) - return failure(); - markScalarCommunication( - loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitScalarPackedProjectedHostSendLoop"); - return success(); -} 
- -LogicalResult emitScalarPackedProjectedHostLocalInsertLoop(MaterializerState& state, - MaterializedClass& ownerClass, - ArrayRef keys, - Value payload, - Value originalOutput, - RankedTensorType fragmentType, - Location loc) { - if (ownerClass.isBatch) - return ownerClass.op->emitError("packed projected host local insert loop expects a scalar owner class"); - if (keys.empty()) - return success(); - - auto payloadType = dyn_cast(payload.getType()); - if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) - return ownerClass.op->emitError("packed projected host local insert loop expects a static ranked payload"); - if (payloadType.getDimSize(0) != static_cast(keys.size())) - return ownerClass.op->emitError("packed projected host local insert loop payload/key count mismatch"); - - auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); - if (ownerIt == ownerClass.hostOutputToResultIndex.end()) - return ownerClass.op->emitError("missing host result slot for packed projected host local insert loop"); - - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - auto originalResult = dyn_cast(originalOutput); - if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) - return ownerClass.op->emitError("packed projected host local insert loop expects a resultful compute_batch output"); - - FailureOr projection = - getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); - if (failed(projection)) - return sourceBatch.emitOpError("failed to recover batch host projection for local insert loop"); - - auto laneArg = sourceBatch.getLaneArgument(); - if (!laneArg) - return sourceBatch.emitOpError("missing compute_batch lane argument for packed projected host local insert loop"); - - auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); - if (!yieldOp) - return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); - - unsigned hostResultIndex = 
ownerIt->second; - if (hostResultIndex >= yieldOp.getNumOperands()) - return ownerClass.op->emitError("host result index out of range for packed projected host local insert loop"); - if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) - return ownerClass.op->emitError("packed projected host local insert loop expected a full host accumulator tensor") - << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() - << " outputType=" << originalOutput.getType(); - - unsigned rank = projection->getMixedOffsets().size(); - SmallVector, 4> offsetsByDim(rank); - SmallVector, 4> sizesByDim(rank); - SmallVector, 4> stridesByDim(rank); - for (ProducerKey key : keys) { - if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() - || key.instance.laneCount != 1) - return ownerClass.op->emitError("packed projected host local insert loop expects one-lane fragments from one output"); - - FailureOr> offsets = - evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); - FailureOr> sizes = - evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); - FailureOr> strides = - evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); - if (failed(offsets) || failed(sizes) || failed(strides)) - return ownerClass.op->emitError("failed to evaluate packed projected host local insert loop coordinates"); - if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) - return ownerClass.op->emitError("packed projected host local insert loop coordinate rank mismatch"); - - for (unsigned dim = 0; dim < rank; ++dim) { - offsetsByDim[dim].push_back((*offsets)[dim]); - sizesByDim[dim].push_back((*sizes)[dim]); - stridesByDim[dim].push_back((*strides)[dim]); - } - } - - Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); - Value upperBound = 
getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); - Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 1); - - state.rewriter.setInsertionPoint(yieldOp); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {yieldOp.getOperand(hostResultIndex)}, - [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { - SmallVector extractOffsets; - SmallVector extractSizes; - SmallVector extractStrides; - extractOffsets.reserve(payloadType.getRank()); - extractSizes.reserve(payloadType.getRank()); - extractStrides.reserve(payloadType.getRank()); - extractOffsets.push_back(flatIndex); - extractSizes.push_back(state.rewriter.getIndexAttr(1)); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { - extractOffsets.push_back(state.rewriter.getIndexAttr(0)); - extractSizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); - extractStrides.push_back(state.rewriter.getIndexAttr(1)); - } - - Value fragment = tensor::ExtractSliceOp::create( - state.rewriter, loc, fragmentType, payload, extractOffsets, extractSizes, extractStrides) - .getResult(); - - SmallVector offsets; - SmallVector sizes; - SmallVector strides; - offsets.reserve(rank); - sizes.reserve(rank); - strides.reserve(rank); - for (unsigned dim = 0; dim < rank; ++dim) { - offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); - sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); - strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); - } - - Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) - .getResult(); - yielded.push_back(updated); - return success(); - }); - if 
(failed(loop)) - return failure(); - - state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); - state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); - return success(); -} - -std::optional tryEmitScalarPackedProjectedHostPublication(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { - if (sourceClass.isBatch || keys.size() <= 1) - return std::nullopt; - - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for projected batch output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (ownerClass.isBatch) - return ownerClass.op->emitError( - "projected batch host output reached a batch owner without an explicit batch publication path"); - FailureOr fragmentType = getLeadingPackedFragmentType(sourceClass.op, payload, keys.size()); - if (failed(fragmentType)) - return std::nullopt; - - if (ownerClass.id == sourceClass.id) - return emitScalarPackedProjectedHostLocalInsertLoop( - state, ownerClass, keys, payload, originalOutput, *fragmentType, loc); - - auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); - auto targetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); - if (failed(sourceCpu) || failed(targetCpu)) - return failure(); - - MessageVector messages; - for ([[maybe_unused]] ProducerKey key : keys) - messages.append(state.nextChannelId++, *sourceCpu, *targetCpu); - - if (failed(messages.verify(sourceClass.op))) - return failure(); - - if (failed(emitScalarPackedProjectedHostSendLoop(state, sourceClass, payload, *fragmentType, messages, loc))) - return failure(); - - return emitProjectedBatchHostReceiveInsertLoop( - state, ownerClass, originalOutput, keys, 
*fragmentType, messages, loc); -} - -void appendPendingProjectedHostReceive(MaterializerState& state, - MaterializedClass& ownerClass, - Value originalOutput, - ProducerKey key, - RankedTensorType fragmentType, - const MessageVector& messages, - Location loc) { - assert(messages.size() == 1 && "pending projected host receive records one message at a time"); - for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { - if (group.originalOutput != originalOutput || group.ownerClassId != ownerClass.id || group.fragmentType != fragmentType) - continue; - group.keys.push_back(key); - group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); - return; - } - - PendingProjectedHostReceiveGroup group { - originalOutput, - ownerClass.id, - fragmentType, - SmallVector{key}, - MessageVector{}, - loc - }; - group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); - state.pendingProjectedHostReceives.push_back(std::move(group)); -} - -LogicalResult flushPendingProjectedHostReceives(MaterializerState& state) { - for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { - if (group.ownerClassId >= state.classes.size()) - return state.func.emitError("pending projected host receive has invalid owner class"); - MaterializedClass& ownerClass = state.classes[group.ownerClassId]; - if (failed(group.messages.verify(ownerClass.op))) - return failure(); - if (group.keys.empty()) - continue; - if (failed(emitProjectedBatchHostReceiveInsertLoop( - state, ownerClass, group.originalOutput, group.keys, group.fragmentType, group.messages, group.loc))) - return failure(); - } - state.pendingProjectedHostReceives.clear(); - return success(); -} - -LogicalResult emitProjectedBatchHostFragment(MaterializerState& state, - MaterializedClass& sourceClass, - ProducerKey key, - Value payload, - Value originalOutput, - Location loc) { - auto ownerIt = 
state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for projected batch output"); - - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - Value ownerPayload = payload; - if (sourceClass.id != ownerClass.id) { - if (ownerClass.isBatch) { - return ownerClass.op->emitError( - "projected batch host fragment reached a batch owner without an explicit batch publication path"); - } - - MessageVector messages; - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); - if (failed(checkedTargetCpu)) - return failure(); - if (!sourceClass.isBatch) { - if (failed(checkedSourceCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - - if (failed(appendSend(state, sourceClass, payload, messages, loc))) - return failure(); - - auto fragmentType = dyn_cast(payload.getType()); - if (!fragmentType) - return sourceClass.op->emitError("projected terminal batch host fragment expects ranked tensor payload"); - appendPendingProjectedHostReceive(state, ownerClass, originalOutput, key, fragmentType, messages, loc); - return success(); - } - else { - ComputeInstance scheduledInstance = getScheduledChunkForLogicalInstance(state, key.instance); - auto sourceCpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); - if (sourceCpuIt == state.schedule.computeToCpuMap.end()) - return sourceClass.op->emitError("missing CPU assignment for projected batch host source"); - - auto localLaneIt = sourceClass.cpuToLane.find(sourceCpuIt->second); - if (localLaneIt == sourceClass.cpuToLane.end()) - return sourceClass.op->emitError("missing local batch lane for projected batch host source"); - - if (failed(checkedSourceCpu = getCheckedCoreId(sourceClass.op, - 
sourceCpuIt->second, - "projected host source core id"))) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - - auto batch = cast(sourceClass.op); - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("missing lane argument for projected batch host source"); - - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value localLane = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, localLaneIt->second); - Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); - Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); - Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); - Value isSourceLane = arith::CmpIOp::create(state.rewriter, loc, arith::CmpIPredicate::eq, *laneArg, localLane); - auto ifOp = scf::IfOp::create(state.rewriter, loc, TypeRange {}, isSourceLane, /*withElseRegion=*/false); - state.rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator()); - SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); - ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, loc); - } - } - - return insertProjectedBatchHostFragment(state, ownerClass, originalOutput, key.instance.laneStart, ownerPayload); -} - LogicalResult emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) { if (!hasLiveExternalUseCached(state, originalOutput)) return success(); - if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) - return sourceClass.op->emitError("cannot set projected terminal batch host output through the generic host path"); - auto ownerIt = state.hostOutputOwners.find(originalOutput); if (ownerIt == state.hostOutputOwners.end()) return 
sourceClass.op->emitError("missing host owner for live external output"); @@ -6009,10 +4125,42 @@ emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, if (sourceClass.id == ownerClass.id) return setHostOutputValue(state, ownerClass, originalOutput, payload); + // Keep the old deadlock-free communication discipline: only scalar-to-scalar + // host-owner forwarding is introduced here. Batch host publication remains on + // the owning batch path; projected terminal batch publication must use the + // explicit projected whole-batch path instead of generic host forwarding. + if (sourceClass.isBatch && ownerClass.isBatch) { + if (sourceClass.cpus.size() != ownerClass.cpus.size()) + return sourceClass.op->emitError("batch host publication requires batch source/owner classes of equal size"); + + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(ownerClass.cpus.size()); + + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "host batch source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus[lane], "host batch owner core id"); + if (failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + } + + if (failed(appendSend(state, sourceClass, payload, messages, payload.getLoc()))) + return failure(); + Value ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, payload.getLoc()); + return setHostOutputValue(state, ownerClass, originalOutput, ownerPayload); + } + if (sourceClass.isBatch) - return sourceClass.op->emitError("batch host publication must be routed through a projection-aware or owning path"); + return sourceClass.op->emitError("batch host publication must be 
routed through the owning/projection-aware path"); if (ownerClass.isBatch) return ownerClass.op->emitError("generic host publication does not support batch host owners"); + if (payload.getType() != originalOutput.getType()) + return sourceClass.op->emitError("cannot forward fragment payload to scalar host owner") + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); MessageVector messages; auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); @@ -6028,11 +4176,11 @@ emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, } LogicalResult emitOutputFanout(MaterializerState& state, - MaterializedClass& sourceClass, - ArrayRef keys, - Value payload, - Value originalOutput, - Location loc) { + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { if (keys.empty()) return success(); @@ -6040,84 +4188,32 @@ LogicalResult emitOutputFanout(MaterializerState& state, if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) return failure(); - if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { - std::optional loopedHostPublication = - tryEmitScalarPackedProjectedHostPublication(state, sourceClass, keys, payload, originalOutput, loc); - if (loopedHostPublication) - return *loopedHostPublication; - - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return sourceClass.op->emitError("projected terminal batch host output expects one logical lane per fragment"); - if (failed(emitProjectedBatchHostFragment(state, sourceClass, key, payload, originalOutput, loc))) - return failure(); - } - return success(); - } - return emitHostCommunication(state, sourceClass, payload, originalOutput); } + bool recordedProjectedHostFragments = false; + if (hasLiveExternalUseCached(state, originalOutput)) { + FailureOr recorded = + 
recordProjectedScalarHostFragmentsFromPackedValue(state, sourceClass, keys, payload, originalOutput, loc); + if (failed(recorded)) + return failure(); + recordedProjectedHostFragments = *recorded; + } + if (!haveSameDestinationClasses(state, keys)) return sourceClass.op->emitError( "cannot materialize batched output whose lanes have different destination equivalence classes"); - if (auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp())) { - if (sourceBatch.getNumResults() != 0 && isTerminalHostBatchOutput(originalOutput, state.oldComputeOps)) { - for (ClassId destinationClass : getDestinationClasses(state, keys.front())) - if (!state.classes[destinationClass].isBatch) - return emitBatchToScalarDestinationDiagnostic(state, sourceClass, keys, originalOutput); - } - } - for (ClassId destinationClass : getDestinationClasses(state, keys.front())) if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); - auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); - if (sourceBatch && sourceBatch.getNumResults() != 0 && hasLiveExternalUseCached(state, originalOutput)) { - if (sourceClass.hostOutputToResultIndex.contains(originalOutput)) { - if (failed(emitProjectedBatchHostOutput(state, sourceClass, keys, originalOutput, payload, loc))) - return failure(); - } - else { - auto ownerIt = state.hostOutputOwners.find(originalOutput); - if (ownerIt == state.hostOutputOwners.end()) - return sourceClass.op->emitError("missing host owner for projected batch output"); + if (hasLiveExternalUseCached(state, originalOutput) && !recordedProjectedHostFragments) + return sourceClass.op->emitError( + "batch host publication requires explicit fragment assembly metadata"); - MaterializedClass& ownerClass = state.classes[ownerIt->second]; - if (ownerClass.isBatch) - return ownerClass.op->emitError( - "projected batch host output reached a batch owner without an explicit batch publication 
path"); - - if (sourceClass.id != ownerClass.id - && failed(emitClassToClassCommunication(state, sourceClass, ownerClass, keys, payload, loc))) - return failure(); - - std::optional loopedHostPublication = - tryEmitProjectedBatchHostReceiveInsertLoop(state, ownerClass, originalOutput, keys, loc); - if (loopedHostPublication) { - if (failed(*loopedHostPublication)) - return failure(); - } - else { - for (ProducerKey key : keys) { - if (key.instance.laneCount != 1) - return sourceClass.op->emitError("projected batch host output expects one logical lane per fragment"); - - std::optional ownerPayload = state.availableValues.lookup(state, key, ownerClass.id); - if (!ownerPayload) - return ownerClass.op->emitError("failed to recover projected batch host fragment after communication"); - - if (failed(insertProjectedBatchHostFragment( - state, ownerClass, originalOutput, key.instance.laneStart, *ownerPayload))) - return failure(); - } - } - } - } else if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) { + if (!recordedProjectedHostFragments && failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) return failure(); - } for (ProducerKey key : keys) state.availableValues.record(key, sourceClass.id, payload); @@ -6302,6 +4398,13 @@ bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScal } +std::optional getConstantIndexValue(Value value) { + APInt constant; + if (matchPattern(value, m_ConstantInt(&constant))) + return constant.getSExtValue(); + return std::nullopt; +} + bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelReceiveOp receive) { std::optional channelId = getConstantIndexValue(receive.getChannelId()); std::optional sourceCoreId = getConstantIndexValue(receive.getSourceCoreId()); @@ -6680,7 +4783,7 @@ LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, size_t flattenedIndexBase = 0; for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { - 
std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); if (contiguousKey) { FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slot.keys.size()); if (failed(slotPackedType)) @@ -7311,9 +5414,6 @@ FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerSt FailureOr materializeWholeBatchInput( MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { - if (failed(materializePendingScalarReceivesForWholeBatchInput(state, targetClass, key, loc))) - return failure(); - FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); if (succeeded(plan)) return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); @@ -7321,6 +5421,327 @@ FailureOr materializeWholeBatchInput( return materializeProjectedWholeBatchInputFromFragments(state, targetClass, key, resultType, loc); } +FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerState& state, + MaterializedClass& sourceClass, + SpatComputeBatch sourceBatch, + size_t resultIndex, + ArrayRef run, + Value packed, + RankedTensorType fragmentType, + Value originalOutput, + Location loc) { + if (!hasLiveExternalUseCached(state, originalOutput)) + return false; + if (packed.getType() == originalOutput.getType() || fragmentType == originalOutput.getType()) + return false; + + auto resultType = dyn_cast(originalOutput.getType()); + if (!resultType || !resultType.hasStaticShape()) + return false; + + FailureOr projection = getBatchResultProjectionInsert(sourceBatch, resultIndex); + if (failed(projection)) + return false; + + std::optional laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) { + sourceBatch.emitOpError("missing compute_batch lane argument while recording projected host fragments"); + return failure(); + } + + for (auto [runIndex, slot] : llvm::enumerate(run)) { + if 
(slot.peers.size() != 1) { + sourceClass.op->emitError("projected scalar host output publication expects scalar one-peer run slots"); + return failure(); + } + + const ComputeInstance& peer = slot.peers.front(); + if (peer.op != sourceBatch.getOperation()) { + sourceClass.op->emitError("projected scalar host output run changed source operation"); + return failure(); + } + if (peer.laneCount != 1) { + sourceClass.op->emitError("projected scalar host output publication expects one logical lane per packed slot") + << " laneStart=" << peer.laneStart << " laneCount=" << peer.laneCount; + return failure(); + } + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, peer.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, peer.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, peer.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) { + sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") + << peer.laneStart; + return failure(); + } + + state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { + originalOutput, + sourceClass.id, + packed, + cast(packed.getType()), + fragmentType, + static_cast(runIndex), + SmallVector(*offsets), + SmallVector(*sizes), + SmallVector(*strides), + peer.laneStart, + loc}); + } + + return true; +} + +FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value packed, + Value originalOutput, + Location loc) { + if (!sourceClass.isBatch || keys.empty()) + return false; + if (!hasLiveExternalUseCached(state, originalOutput)) + return false; + if (packed.getType() == originalOutput.getType()) + return false; + + auto resultType = dyn_cast(originalOutput.getType()); + auto packedType = dyn_cast(packed.getType()); + auto 
sourceBatch = dyn_cast_or_null(keys.front().instance.op); + if (!resultType || !resultType.hasStaticShape() || !packedType || !packedType.hasStaticShape() || !sourceBatch) + return false; + if (keys.front().resultIndex >= static_cast(sourceBatch.getNumResults())) + return false; + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, keys.front().resultIndex); + if (failed(projection)) + return false; + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument while recording projected host fragments"); + + FailureOr> firstSizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, keys.front().instance.laneStart); + if (failed(firstSizes)) + return sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") + << keys.front().instance.laneStart; + + SmallVector fragmentShape(*firstSizes); + auto fragmentType = RankedTensorType::get(fragmentShape, packedType.getElementType(), packedType.getEncoding()); + if (fragmentType == originalOutput.getType()) + return false; + + bool operandIsDim0Packed = false; + if (packedType != fragmentType) { + if (packedType.getRank() == 0 || packedType.getDimSize(0) % static_cast(keys.size()) != 0) + return sourceClass.op->emitError( + "projected packed host publication requires either direct fragment operands or evenly dim-0 packed fragments") + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + + SmallVector packedFragmentShape(packedType.getShape()); + packedFragmentShape[0] /= static_cast(keys.size()); + if (packedFragmentShape != fragmentShape) + return sourceClass.op->emitError( + "projected packed host publication fragment shape does not match projected slice size") + << " packedType=" << packedType << " fragmentType=" << fragmentType << " keyCount=" << keys.size(); + operandIsDim0Packed = true; + } + + for (auto [fragmentIndex, key] 
: llvm::enumerate(keys)) { + if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != keys.front().resultIndex || key.instance.laneCount != 1) + return sourceClass.op->emitError("projected packed host publication requires one-lane keys from one producer result"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) + return sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") + << key.instance.laneStart; + if (SmallVector(*sizes) != fragmentShape) + return sourceClass.op->emitError( + "projected packed host publication requires one operand to map to a consistent fragment shape"); + + state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { + originalOutput, + sourceClass.id, + packed, + packedType, + fragmentType, + operandIsDim0Packed ? 
static_cast(fragmentIndex) : -1, + SmallVector(*offsets), + SmallVector(*sizes), + SmallVector(*strides), + key.instance.laneStart, + loc}); + } + + return true; +} + +LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { + if (state.pendingProjectedHostOutputFragments.empty()) + return success(); + + DenseMap> byOutput; + for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) + byOutput[fragment.originalOutput].push_back(&fragment); + + SmallVector outputs; + outputs.reserve(byOutput.size()); + for (const auto& entry : byOutput) + outputs.push_back(entry.first); + llvm::sort(outputs, [](Value lhs, Value rhs) { + return reinterpret_cast(lhs.getAsOpaquePointer()) + < reinterpret_cast(rhs.getAsOpaquePointer()); + }); + + for (Value originalOutput : outputs) { + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) { + Operation* anchor = originalOutput.getDefiningOp() ? originalOutput.getDefiningOp() : state.func.getOperation(); + return anchor->emitError("missing host owner for projected host output fragments"); + } + + MaterializedClass* ownerClass = &state.classes[ownerIt->second]; + if (ownerClass->isBatch) { + auto scalarOwnerIt = llvm::find_if(state.classes, [](const MaterializedClass& candidate) { + return !candidate.isBatch; + }); + if (scalarOwnerIt == state.classes.end()) + return ownerClass->op->emitError( + "projected host output finalization requires a scalar assembly class when the preferred host owner is batch"); + ownerClass = &*scalarOwnerIt; + state.hostOutputOwners[originalOutput] = ownerClass->id; + } + + auto resultType = dyn_cast(originalOutput.getType()); + if (!resultType || !resultType.hasStaticShape()) + return ownerClass->op->emitError("projected host output must have static ranked tensor type"); + + SmallVector& fragments = byOutput[originalOutput]; + llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, + 
const PendingProjectedHostOutputFragment* rhs) { + if (lhs->sourceLane != rhs->sourceLane) + return lhs->sourceLane < rhs->sourceLane; + if (lhs->sourceClass != rhs->sourceClass) + return lhs->sourceClass < rhs->sourceClass; + return std::lexicographical_compare(lhs->offsets.begin(), + lhs->offsets.end(), + rhs->offsets.begin(), + rhs->offsets.end()); + }); + + bool allFromSameSourceClass = + llvm::all_of(fragments, [&](const PendingProjectedHostOutputFragment* fragment) { + return fragment->sourceClass == fragments.front()->sourceClass; + }); + if (allFromSameSourceClass) { + ownerClass = &state.classes[fragments.front()->sourceClass]; + state.hostOutputOwners[originalOutput] = ownerClass->id; + } + + state.rewriter.setInsertionPoint(ownerClass->body->getTerminator()); + Location loc = fragments.front()->loc; + SmallVector reconciliatorOperands; + SmallVector fragmentOperandIndices; + SmallVector flatOffsets; + SmallVector flatSizes; + SmallVector flatStrides; + DenseMap operandIndicesByValue; + + for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { + Value operand = fragmentRecord->operand; + MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; + + if (fragmentRecord->sourceClass != ownerClass->id) { + if (sourceClass.isBatch || ownerClass->isBatch) + return sourceClass.op->emitError( + "projected host output fragment assembly requires scalarized cross-class operands before finalization"); + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, + sourceClass.cpus.front(), + "projected host output source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass->op, + ownerClass->cpus.front(), + "projected host output target core id"); + if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc))) + return 
failure(); + operand = appendReceive(state, *ownerClass, fragmentRecord->operandType, messages, fragmentRecord->loc); + } else if (!ownerClass->isBatch) { + FailureOr localOperand = materializeTensorValueForMaterializedClassUse( + state, + *ownerClass, + operand, + ownerClass->op, + "projected host output assembly tried to reuse a non-local fragment tensor"); + if (failed(localOperand)) + return failure(); + operand = *localOperand; + } + + auto [operandIt, inserted] = + operandIndicesByValue.try_emplace(operand, static_cast(reconciliatorOperands.size())); + if (inserted) + reconciliatorOperands.push_back(operand); + fragmentOperandIndices.push_back(operandIt->second); + llvm::append_range(flatOffsets, fragmentRecord->offsets); + llvm::append_range(flatSizes, fragmentRecord->sizes); + llvm::append_range(flatStrides, fragmentRecord->strides); + + auto operandType = dyn_cast(operand.getType()); + if (!operandType || !operandType.hasStaticShape()) + return ownerClass->op->emitError("projected host output assembly requires static ranked tensor operands"); + if (fragmentRecord->packedFragmentIndex >= 0) { + int64_t fragmentSize0 = fragmentRecord->fragmentType.getDimSize(0); + if (fragmentSize0 <= 0 || operandType.getRank() == 0) + return ownerClass->op->emitError("packed projected host output assembly requires ranked fragment operands"); + int64_t start = fragmentRecord->packedFragmentIndex * fragmentSize0; + int64_t end = start + fragmentSize0; + if (start < 0 || end > operandType.getDimSize(0)) + return ownerClass->op->emitError("packed projected host output fragment index is out of bounds"); + } + } + + if (reconciliatorOperands.empty()) + return ownerClass->op->emitError("missing projected host output fragments"); + + Value input = reconciliatorOperands.front(); + ValueRange extraFragments = ValueRange(reconciliatorOperands).drop_front(); + auto reconciliator = spatial::SpatReconciliatorOp::create( + state.rewriter, + loc, + resultType, + input, + extraFragments, + 
state.rewriter.getStringAttr("nchw"), + state.rewriter.getStringAttr("fragmented"), + state.rewriter.getDenseI64ArrayAttr(flatOffsets), + state.rewriter.getDenseI64ArrayAttr(flatSizes), + state.rewriter.getStringAttr("identity"), + state.rewriter.getStringAttr("fragment_assembly"), + state.rewriter.getDenseI64ArrayAttr(fragmentOperandIndices), + state.rewriter.getDenseI64ArrayAttr(flatStrides), + state.rewriter.getStringAttr("disjoint"), + state.rewriter.getStringAttr("complete")); + + if (failed(setHostOutputValue(state, *ownerClass, originalOutput, reconciliator.getOutput()))) + return failure(); + } + + return success(); +} + FailureOr resolveInputValue(MaterializerState& state, MaterializedClass& targetClass, Value input, @@ -7356,13 +5777,6 @@ FailureOr resolveInputValue(MaterializerState& state, if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) return rejectNonLocalResolvedValue(*value); - if (auto pendingReceive = lookupPendingScalarReceiveIndex(state, *producer, targetClass.id)) { - FailureOr received = - materializePendingScalarReceive(state, targetClass, *pendingReceive, consumerInstance.op->getLoc()); - if (failed(received)) - return failure(); - return rejectNonLocalResolvedValue(*received); - } if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { size_t laneCount = targetClass.cpus.size(); @@ -7495,9 +5909,23 @@ LogicalResult mapInputs(MaterializerState& state, if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) continue; - FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); - if (failed(mapped)) - return batch.emitOpError("failed to resolve materialized compute_batch input"); + FailureOr mapped = failure(); + if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input); + wholeBatchProducer && !canUseProjectedLaneInput(state, batch, static_cast(index), input, 
instance)) { + mapped = materializeWholeBatchInput( + state, targetClass, *wholeBatchProducer, input.getType(), batch.getOperation()->getLoc()); + if (failed(mapped)) + return batch.emitOpError("failed to materialize whole-batch compute_batch input") + << " #" << index << " from '" << wholeBatchProducer->instance.op->getName() + << "' laneStart=" << wholeBatchProducer->instance.laneStart + << " laneCount=" << wholeBatchProducer->instance.laneCount + << " resultIndex=" << wholeBatchProducer->resultIndex; + } else { + mapped = resolveInputValue(state, targetClass, input, instance, indexing); + if (failed(mapped)) + return batch.emitOpError("failed to resolve materialized compute_batch input"); + } + auto inputArg = batch.getInputArgument(index); if (!inputArg) return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); @@ -7599,29 +6027,6 @@ std::optional lookupProjectedExtractReplacement(Mat return classIt->second; } -bool requiresConstantProjectionSlotIndex(MaterializerState& state, - MaterializedClass& targetClass, - Operation* sourceOp) { - bool requiresConstantIndex = false; - sourceOp->walk([&](tensor::ExtractSliceOp extract) { - if (requiresConstantIndex) - return WalkResult::interrupt(); - - std::optional replacement = - lookupProjectedExtractReplacement(state, targetClass, extract); - if (!replacement) - return WalkResult::advance(); - - if (replacement->layout.payloadFragmentCount != replacement->layout.fragmentsPerLogicalSlot) { - requiresConstantIndex = true; - return WalkResult::interrupt(); - } - - return WalkResult::advance(); - }); - return requiresConstantIndex; -} - LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, MaterializedClass& targetClass, Operation& originalOp, @@ -8365,14 +6770,8 @@ FailureOr> materializeBatchOutputGroupLoop(MaterializerSta for (auto [outputIndex, output] : llvm::enumerate(*produced)) { auto fragmentType = cast(output.getType()); Value acc = 
iterArgs[outputIndex]; - FailureOr firstOffset = - scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, fragmentType.getDimSize(0), loc); - if (failed(firstOffset)) - return failure(); - FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, output, acc, *firstOffset); - if (failed(next)) - return failure(); - yielded.push_back(*next); + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); } return success(); }); @@ -8512,20 +6911,6 @@ bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, return false; } -bool canRegisterDeferredLocalPackedRun(MaterializerState& state, ArrayRef run) { - for (const MaterializationRunSlot& slot : run) { - for (const ComputeInstance& peer : slot.peers) { - for (Value input : getComputeInstanceInputs(peer)) { - std::optional producer = getInputRequestProducerKey(input, peer); - if (producer && isWholeBatchProducerKey(*producer)) - return false; - } - } - } - - return true; -} - void markMaterializationRunSlots(MaterializerState& state, ClassId classId, SlotId startSlot, @@ -8549,13 +6934,11 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, auto sourceBatch = cast(getMaterializationRunSourceOp(run)); SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = getMaterializationRunLoc(run); - bool canDeferLocalPackedRun = canRegisterDeferredLocalPackedRun(state, run); for (const OutputDestinationGroup& group : groups) { - bool canUseLocalOnlyPackedRun = run.size() > 1 && group.destinationClasses.empty() - && !hasMaterializationRunGroupLiveExternalUse(state, run, group) - && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group); - if (canUseLocalOnlyPackedRun && canDeferLocalPackedRun) { + if (run.size() > 1 && group.destinationClasses.empty() + && 
!hasMaterializationRunGroupLiveExternalUse(state, run, group) + && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group)) { for (size_t resultIndex : group.resultIndices) { if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); @@ -8580,14 +6963,21 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, Type fragmentType = fragmentTypes[resultIndex]; SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); - if (run.size() == 1) { - if (failed(emitOutputFanout(state, targetClass, keys, packed, firstOriginalOutputs[resultIndex], loc))) - return failure(); - continue; - } + auto rankedFragmentType = cast(fragmentType); + Value representativeOriginalOutput = firstOriginalOutputs[resultIndex]; + FailureOr recordedProjectedHostFragments = recordProjectedScalarHostFragmentsFromPackedRun( + state, targetClass, sourceBatch, resultIndex, run, packed, rankedFragmentType, representativeOriginalOutput, loc); + if (failed(recordedProjectedHostFragments)) + return failure(); - if (canUseLocalOnlyPackedRun) { - if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) + if (run.size() == 1) { + if (*recordedProjectedHostFragments) { + if (failed(emitScalarSourceCommunication(state, targetClass, keys, packed, loc))) + return failure(); + continue; + } + + if (failed(emitOutputFanout(state, targetClass, keys, packed, representativeOriginalOutput, loc))) return failure(); continue; } @@ -8598,19 +6988,9 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) return failure(); - Value representativeOutput = firstOriginalOutputs[resultIndex]; - if (hasLiveExternalUseCached(state, representativeOutput) - && isProjectedTerminalBatchHostOutput(representativeOutput, state.oldComputeOps)) 
{ - std::optional groupedHostPublication = - tryEmitScalarPackedProjectedHostPublication(state, targetClass, keys, packed, representativeOutput, loc); - if (groupedHostPublication) { - if (failed(*groupedHostPublication)) - return failure(); - continue; - } - } + if (*recordedProjectedHostFragments) + continue; - auto rankedFragmentType = cast(fragmentType); for (auto [runIndex, slot] : llvm::enumerate(run)) { assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); @@ -8621,19 +7001,9 @@ LogicalResult materializeScalarBatchRun(MaterializerState& state, continue; state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - FailureOr fragment = - getPackedSliceForRunIndex(state, targetClass, packed, rankedFragmentType, runIndex, loc); - if (failed(fragment)) - return failure(); + Value fragment = getPackedSliceForRunIndex(state, targetClass.op, packed, rankedFragmentType, runIndex, loc); - if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { - ProducerKey key {slot.peers.front(), resultIndex}; - if (failed(emitProjectedBatchHostFragment(state, targetClass, key, *fragment, originalOutput, loc))) - return failure(); - continue; - } - - if (failed(emitHostCommunication(state, targetClass, *fragment, originalOutput))) + if (failed(emitHostCommunication(state, targetClass, fragment, originalOutput))) return failure(); } } @@ -8689,88 +7059,6 @@ bool canCompactBatchClassRun(MaterializerState& state, return true; } -LogicalResult registerMaterializedBatchRunHostOutputs(MaterializerState& state, - MaterializedClass& targetClass, - ArrayRef run, - const OutputDestinationGroup& group) { - ArrayRef originalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); - for (size_t resultIndex : group.resultIndices) { - if (resultIndex >= originalOutputs.size()) - return targetClass.op->emitError("batch materialization host output index out of range"); - - Value originalOutput = 
originalOutputs[resultIndex]; - if (!hasLiveExternalUseCached(state, originalOutput)) - continue; - - auto resultIt = targetClass.hostOutputToResultIndex.find(originalOutput); - if (resultIt == targetClass.hostOutputToResultIndex.end()) - return targetClass.op->emitError("missing host result slot for materialized batch output"); - - state.hostReplacements[originalOutput] = targetClass.op->getResult(resultIt->second); - } - - return success(); -} - -LogicalResult verifyMaterializedHostOutputs(MaterializerState& state) { - for (SpatCompute compute : state.func.getOps()) { - auto yieldOp = dyn_cast_or_null(compute.getBody().front().getTerminator()); - if (!yieldOp) - return compute.emitOpError("expected spat.yield terminator in materialized compute"); - if (compute.getNumResults() != yieldOp.getNumOperands()) - return compute.emitOpError("materialized compute result count does not match spat.yield operand count"); - for (auto [result, yielded] : llvm::zip(compute.getResults(), yieldOp.getOperands())) - if (result.getType() != yielded.getType()) - return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand"); - } - - for (SpatChannelReceiveOp receive : state.func.getOps()) { - if (!receive.getOutput().use_empty()) - continue; - return receive.emitOpError("materialized channel_receive result must have at least one use"); - } - - for (const MaterializedClass& materializedClass : state.classes) { - if (!materializedClass.isBatch || materializedClass.hostOutputs.empty()) - continue; - - auto batch = dyn_cast(materializedClass.op); - auto inParallel = dyn_cast_or_null(materializedClass.body->getTerminator()); - if (!batch || !inParallel) - return materializedClass.op->emitError("expected resultful materialized compute_batch host owner"); - - for (Value hostOutput : materializedClass.hostOutputs) { - auto ownerIt = materializedClass.hostOutputToResultIndex.find(hostOutput); - if (ownerIt == materializedClass.hostOutputToResultIndex.end()) - 
return materializedClass.op->emitError("missing host result slot for materialized compute_batch host output"); - - auto outputArg = batch.getOutputArgument(ownerIt->second); - if (!outputArg) - return batch.emitOpError("missing output block argument for materialized compute_batch host output"); - - bool foundProjection = false; - for (Operation& op : inParallel.getRegion().front()) { - auto insert = dyn_cast(&op); - if (!insert || insert.getDest() != *outputArg) - continue; - foundProjection = true; - break; - } - - if (!foundProjection) - return batch.emitOpError( - "materialized terminal compute_batch host output is missing tensor.parallel_insert_slice publication"); - } - } - - for (const auto& [originalOutput, replacement] : state.hostReplacements) - if (originalOutput.getType() != replacement.getType()) - return replacement.getDefiningOp()->emitOpError("host output replacement type does not match original output type") - << " replacementType=" << replacement.getType() << " outputType=" << originalOutput.getType(); - - return success(); -} - Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { auto batch = cast(targetClass.op); auto laneArg = batch.getLaneArgument(); @@ -8977,7 +7265,6 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, auto sourceBatch = cast(run.front().peers.front().op); SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); Location loc = sourceBatch.getLoc(); - bool constantProjectionSlotIndex = requiresConstantProjectionSlotIndex(state, targetClass, sourceBatch); for (const OutputDestinationGroup& group : groups) { SmallVector sendPlans; @@ -8988,15 +7275,17 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); - if 
(constantProjectionSlotIndex) { - for (auto [slotIndex, slot] : llvm::enumerate(run)) { - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - - Value slotIndexValue = - getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(slotIndex)); - Value sourceLane = getOrCreateIndexConstant(state.constantFolder, targetClass.op, slot.peers.front().laneStart); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndexValue, loc); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); FailureOr> produced = cloneBatchBodyForLane(state, @@ -9004,8 +7293,7 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, getScheduledChunkForLogicalInstance(state, run.front().peers.front()), sourceLane, group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndexValue, - .projectionSlotIndex = slotIndexValue}); + CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); if (failed(produced)) return failure(); @@ -9017,43 +7305,10 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); } - } - } else { - state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - auto loop = buildNormalizedScfFor( - state.rewriter, - loc, - lowerBound, - upperBound, - step, - ValueRange {}, - [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { - Value 
sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); - Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); - - FailureOr> produced = - cloneBatchBodyForLane(state, - targetClass, - getScheduledChunkForLogicalInstance(state, run.front().peers.front()), - sourceLane, - group.resultIndices, - CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); - if (failed(produced)) - return failure(); - - for (const BatchRunSendPlan& plan : sendPlans) { - auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); - if (resultIt == group.resultIndices.end()) - return failure(); - - size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); - appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); - } - return success(); - }); - if (failed(loop)) - return failure(); - } + return success(); + }); + if (failed(loop)) + return failure(); for (const BatchRunSendPlan& plan : sendPlans) { if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) @@ -9062,9 +7317,6 @@ LogicalResult materializeBatchClassRun(MaterializerState& state, if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) return failure(); } - - if (failed(registerMaterializedBatchRunHostOutputs(state, targetClass, run, group))) - return failure(); } return success(); @@ -9147,8 +7399,7 @@ FailureOr createReceiveConcatLoop(MaterializerState& state, assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); assert(!messages.empty() && "expected at least one receive"); - Operation* insertionPoint = targetClass.body->getTerminator(); - state.rewriter.setInsertionPoint(insertionPoint); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); Value init = tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), 
concatType.getElementType()).getResult(); return emitIndexedFragmentInsertLoop( @@ -9169,267 +7420,101 @@ FailureOr createReceiveConcatLoop(MaterializerState& state, loc); } +bool valueMayEvaluateToCore(Value value, int64_t coreId) { + if (std::optional constant = getConstantIndexValue(value)) + return *constant == coreId; -std::optional getDirectCommunicationOrderKey(Operation* op) { - if (!op) - return std::nullopt; - - Value channelId; - Value sourceCoreId; - Value targetCoreId; - if (auto send = dyn_cast(op)) { - channelId = send.getChannelId(); - sourceCoreId = send.getSourceCoreId(); - targetCoreId = send.getTargetCoreId(); - } - else if (auto receive = dyn_cast(op)) { - channelId = receive.getChannelId(); - sourceCoreId = receive.getSourceCoreId(); - targetCoreId = receive.getTargetCoreId(); - } - else { - return std::nullopt; - } - - auto channel = getConstantIndexValue(channelId); - auto source = getConstantIndexValue(sourceCoreId); - auto target = getConstantIndexValue(targetCoreId); - if (!channel || !source || !target) - return std::nullopt; - - return computeBlockingCommunicationOrderKey( - static_cast(*source), static_cast(*target), *channel); -} - -std::optional getScalarCommunicationOrderKey(Operation* op) { - if (!op) - return std::nullopt; - if (auto order = op->getAttrOfType(kRaptorCommOrderAttr)) - return order.getInt(); - if (auto directOrder = getDirectCommunicationOrderKey(op)) - return directOrder; - if (auto channel = op->getAttrOfType(kRaptorMinChannelIdAttr)) - return channel.getInt(); - return std::nullopt; -} - -bool isReorderableScalarCommunication(Operation* op) { - if (!getScalarCommunicationOrderKey(op).has_value()) + auto affineApply = value.getDefiningOp(); + if (!affineApply) return false; - // The global-order repair is intentionally conservative: it may reorder - // send-side projections, but it must not move receives or any other - // communication op that defines SSA values. 
Moving a receive after one of - // its users breaks MLIR dominance; moving it before the source can produce - // the payload can also create a receive/receive deadlock. Receives therefore - // have to be placed correctly by the materializer when they are created. - // Direct spat.channel_send operations are included even when they were not - // produced by appendScalarSendLoop and therefore do not carry raptor.* - // attributes yet. This is needed for large scalar-to-scalar payload transfers - // that must be hoisted before reciprocal receives. - return isa(op) || (op->getNumResults() == 0 && op->hasAttr(kRaptorMinChannelIdAttr)); + AffineMap map = affineApply.getAffineMap(); + if (map.getNumResults() != 1 || map.getNumDims() != 1 || map.getNumSymbols() != 0 + || affineApply.getMapOperands().size() != 1) + return false; + + auto iv = dyn_cast(affineApply.getMapOperands().front()); + if (!iv) + return false; + + auto loop = dyn_cast_or_null(iv.getOwner()->getParentOp()); + if (!loop || loop.getInductionVar() != iv) + return false; + + std::optional lower = getConstantIndexValue(loop.getLowerBound()); + std::optional upper = getConstantIndexValue(loop.getUpperBound()); + std::optional step = getConstantIndexValue(loop.getStep()); + if (!lower || !upper || !step || *step <= 0) + return false; + + for (int64_t iteration = *lower; iteration < *upper; iteration += *step) { + FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef{iteration}); + if (succeeded(evaluated) && *evaluated == coreId) + return true; + } + + return false; } -Operation* getLaterOperationInBlock(Operation* lhs, Operation* rhs) { - if (!lhs) - return rhs; - if (!rhs) - return lhs; - return lhs->isBeforeInBlock(rhs) ? rhs : lhs; -} - -Operation* getNextInsertionPointAfter(Operation* op, Block& block) { - if (!op) - return &block.front(); - Operation* next = op->getNextNode(); - return next ? 
next : block.getTerminator(); -} - -bool hasConstantRoutingOperands(SpatChannelSendOp send) { - return getConstantIndexValue(send.getChannelId()).has_value() - && getConstantIndexValue(send.getSourceCoreId()).has_value() - && getConstantIndexValue(send.getTargetCoreId()).has_value(); -} - -Operation* getLatestSameBlockOperandDefinition(Operation* root, Block& block) { - Operation* latest = nullptr; - - auto consider = [&](Value value) { - Operation* definingOp = value.getDefiningOp(); - if (!definingOp || definingOp->getBlock() != &block || definingOp == root) +bool operationContainsReceiveFromPeer(Operation& op, int64_t localCore, int64_t peerCore, Type payloadType) { + bool found = false; + op.walk([&](SpatChannelReceiveOp receive) { + if (receive.getOutput().getType() != payloadType) return; - latest = getLaterOperationInBlock(latest, definingOp); - }; - - // For direct sends with constant routing operands, only the payload is a real - // scheduling dependency. The channel/source/target constants can be - // rematerialized at the new insertion point. Treating those constants as hard - // dependencies prevents the repair from hoisting a ready send above an early - // receive, which is exactly the receive/receive deadlock pattern reported by - // the static communication checker. 
- if (auto send = dyn_cast(root)) { - if (hasConstantRoutingOperands(send)) { - consider(send.getInput()); - return latest; - } - } - - for (Value operand : root->getOperands()) - consider(operand); - - for (Region& region : root->getRegions()) { - region.walk([&](Operation* nested) { - if (nested == root) - return; - for (Value operand : nested->getOperands()) - consider(operand); - }); - } - - return latest; -} - -void rematerializeDirectSendRoutingConstantsAt(MaterializerState& state, - SpatChannelSendOp send, - Operation* insertionPoint) { - if (!send || !insertionPoint || !hasConstantRoutingOperands(send)) - return; - - auto channel = getConstantIndexValue(send.getChannelId()); - auto source = getConstantIndexValue(send.getSourceCoreId()); - auto target = getConstantIndexValue(send.getTargetCoreId()); - if (!channel || !source || !target) - return; - - OpBuilder::InsertionGuard guard(state.rewriter); - state.rewriter.setInsertionPoint(insertionPoint); - Location loc = send.getLoc(); - Value newChannel = arith::ConstantIndexOp::create(state.rewriter, loc, *channel); - Value newSource = arith::ConstantIndexOp::create(state.rewriter, loc, *source); - Value newTarget = arith::ConstantIndexOp::create(state.rewriter, loc, *target); - send->setOperand(0, newChannel); - send->setOperand(1, newSource); - send->setOperand(2, newTarget); -} - -LogicalResult reorderScalarClassCommunicationByGlobalOrder(MaterializerState& state, - MaterializedClass& materializedClass) { - if (materializedClass.isBatch) - return success(); - - Block& block = *materializedClass.body; - Operation* terminator = block.getTerminator(); - SmallVector communicationOps; - for (Operation& op : block) { - if (&op == terminator) - break; - if (isReorderableScalarCommunication(&op)) - communicationOps.push_back(&op); - } - - if (communicationOps.size() < 2) - return success(); - - llvm::stable_sort(communicationOps, [](Operation* lhs, Operation* rhs) { - std::optional lhsOrder = 
getScalarCommunicationOrderKey(lhs); - std::optional rhsOrder = getScalarCommunicationOrderKey(rhs); - if (lhsOrder != rhsOrder) - return lhsOrder.value_or(std::numeric_limits::max()) - < rhsOrder.value_or(std::numeric_limits::max()); - return lhs->isBeforeInBlock(rhs); + if (!valueMayEvaluateToCore(receive.getTargetCoreId(), localCore)) + return; + if (!valueMayEvaluateToCore(receive.getSourceCoreId(), peerCore)) + return; + found = true; }); - - Operation* lastPlacedCommunication = nullptr; - for (Operation* communication : communicationOps) { - if (communication->getBlock() != &block) - return materializedClass.op->emitError("scalar communication global-order repair saw a moved operation"); - - Operation* dependency = getLatestSameBlockOperandDefinition(communication, block); - Operation* anchor = getLaterOperationInBlock(lastPlacedCommunication, dependency); - Operation* insertionPoint = getNextInsertionPointAfter(anchor, block); - - if (insertionPoint != communication && communication->getNextNode() != insertionPoint) { - if (auto send = dyn_cast(communication)) - rematerializeDirectSendRoutingConstantsAt(state, send, insertionPoint); - communication->moveBefore(insertionPoint); - } - - lastPlacedCommunication = communication; - } - - return success(); + return found; } -LogicalResult reorderScalarCommunicationsByGlobalOrder(MaterializerState& state) { - for (MaterializedClass& materializedClass : state.classes) - if (failed(reorderScalarClassCommunicationByGlobalOrder(state, materializedClass))) - return failure(); - return success(); -} - - -Operation* getEarliestOperationInBlock(Operation* lhs, Operation* rhs) { - if (!lhs) - return rhs; - if (!rhs) - return lhs; - return lhs->isBeforeInBlock(rhs) ? 
lhs : rhs; -} - -Operation* getTopLevelOperationInBlock(Operation* op, Block& block) { - for (Operation* current = op; current; current = current->getParentOp()) { - if (current->getBlock() == &block) - return current; - } - return nullptr; -} - -Operation* findEarliestTopLevelUse(Operation* producer, Block& block) { - Operation* earliest = nullptr; - for (Value result : producer->getResults()) { - for (Operation* user : result.getUsers()) { - Operation* topLevelUser = getTopLevelOperationInBlock(user, block); - if (!topLevelUser || topLevelUser == producer) - continue; - earliest = getEarliestOperationInBlock(earliest, topLevelUser); - } - } - return earliest; -} - -LogicalResult sinkScalarReceivesToFirstUse(MaterializerState& state) { +LogicalResult orderLowerCoreScalarSendsAfterMatchingReceives(MaterializerState& state) { for (MaterializedClass& materializedClass : state.classes) { - if (materializedClass.isBatch) + if (materializedClass.isBatch || materializedClass.cpus.empty()) continue; - Block& block = *materializedClass.body; - Operation* terminator = block.getTerminator(); - SmallVector receives; - for (Operation& op : block) { - if (&op == terminator) + int64_t localCore = static_cast(materializedClass.cpus.front()); + Block* body = materializedClass.body; + if (!body) + continue; + + bool changed = true; + while (changed) { + changed = false; + for (Operation& op : llvm::make_early_inc_range(*body)) { + if (&op == body->getTerminator()) + break; + + auto send = dyn_cast(&op); + if (!send) + continue; + + std::optional sourceCore = getConstantIndexValue(send.getSourceCoreId()); + std::optional targetCore = getConstantIndexValue(send.getTargetCoreId()); + if (!sourceCore || !targetCore || *sourceCore != localCore || *sourceCore >= *targetCore) + continue; + + Operation* matchingReceiveContainer = nullptr; + for (Operation* candidate = op.getNextNode(); candidate && candidate != body->getTerminator(); + candidate = candidate->getNextNode()) { + if 
(operationContainsReceiveFromPeer(*candidate, localCore, *targetCore, send.getInput().getType())) { + matchingReceiveContainer = candidate; + break; + } + } + + if (!matchingReceiveContainer) + continue; + + op.moveAfter(matchingReceiveContainer); + changed = true; break; - if (isa(&op)) - receives.push_back(&op); - } - - for (Operation* receive : receives) { - if (receive->getBlock() != &block) - continue; - - Operation* firstUse = findEarliestTopLevelUse(receive, block); - if (!firstUse || firstUse == receive || firstUse->getBlock() != &block) - continue; - - if (!receive->isBeforeInBlock(firstUse)) - continue; - - if (receive->getNextNode() == firstUse) - continue; - - receive->setAttr("raptor.receive_sunk_to_first_use", UnitAttr::get(receive->getContext())); - receive->moveBefore(firstUse); + } } } + return success(); } @@ -9479,23 +7564,15 @@ MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& sch if (failed(materializeInstanceSlot(state, instance))) return failure(); + if (failed(finalizeProjectedHostOutputFragments(state))) + return failure(); + if (failed(orderLowerCoreScalarSendsAfterMatchingReceives(state))) + return failure(); + for (MaterializedClass& materializedClass : state.classes) if (failed(localizeAllScheduledBodyCaptures(state, materializedClass))) return failure(); - if (failed(flushPendingProjectedHostReceives(state))) - return failure(); - - if (pimMaterializeScalarFanoutGlobalOrder) { - if (failed(sinkScalarReceivesToFirstUse(state))) - return failure(); - if (failed(reorderScalarCommunicationsByGlobalOrder(state))) - return failure(); - } - - if (failed(verifyMaterializedHostOutputs(state))) - return failure(); - replaceHostUses(state); if (failed(eraseOldComputeOps(state))) return failure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk new file mode 100644 index 
0000000..82dbf03 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.bkk @@ -0,0 +1,9510 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include +#include + +#include "MaterializeMergeSchedule.hpp" +#include "Scheduling/ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { +namespace { + +using CpuId = size_t; +using ClassId = size_t; +using SlotId = size_t; + +static FailureOr getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) { + return pim::checkedI32(static_cast(cpu), anchor, fieldName); +} + +static FailureOr> +getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) { + SmallVector coreIds; + coreIds.reserve(cpus.size()); + for (CpuId cpu : cpus) { + auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); + if (failed(checkedCoreId)) + return failure(); + coreIds.push_back(*checkedCoreId); + } + return coreIds; +} + +struct 
MessageVector { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + size_t size() const { return channelIds.size(); } + bool empty() const { return channelIds.empty(); } + + LogicalResult verify(Operation* anchor) const { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return anchor->emitError("message metadata is inconsistent"); + return success(); + } + + void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { + channelIds.push_back(channelId); + sourceCoreIds.push_back(sourceCoreId); + targetCoreIds.push_back(targetCoreId); + } + + void append(ArrayRef channels, ArrayRef sources, ArrayRef targets) { + assert(channels.size() == sources.size() && "channel/source count mismatch"); + assert(channels.size() == targets.size() && "channel/target count mismatch"); + llvm::append_range(channelIds, channels); + llvm::append_range(sourceCoreIds, sources); + llvm::append_range(targetCoreIds, targets); + } + + MessageVector slice(size_t offset, size_t count) const { + MessageVector result; + result.append(ArrayRef(channelIds).slice(offset, count), + ArrayRef(sourceCoreIds).slice(offset, count), + ArrayRef(targetCoreIds).slice(offset, count)); + return result; + } +}; + +struct ProducerKey { + ComputeInstance instance; + size_t resultIndex = 0; + + bool operator==(const ProducerKey& other) const { + return instance == other.instance && resultIndex == other.resultIndex; + } +}; + +struct ProducerKeyInfo { + static ProducerKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProducerKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProducerKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); + } + + static bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return 
lhs == rhs; } +}; + +struct SameClassConsumerLookupKey { + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const SameClassConsumerLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct SameClassConsumerLookupKeyInfo { + static SameClassConsumerLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static SameClassConsumerLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const SameClassConsumerLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); + } + + static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { + return lhs == rhs; + } +}; + +struct WholeBatchAssemblyLookupKey { + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const WholeBatchAssemblyLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct WholeBatchAssemblyLookupKeyInfo { + static WholeBatchAssemblyLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static WholeBatchAssemblyLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); + } + + static bool isEqual(const WholeBatchAssemblyLookupKey& lhs, const WholeBatchAssemblyLookupKey& rhs) { + return lhs == rhs; + 
} +}; + +using ClassSlotKey = std::pair; + +struct MaterializedClass { + ClassId id = 0; + SmallVector cpus; + Operation* op = nullptr; + Block* body = nullptr; + bool isBatch = false; + + DenseMap cpuToLane; + SmallVector weights; + SmallVector inputs; + SmallVector hostOutputs; + DenseMap weightArgs; + DenseMap inputArgs; + DenseMap hostOutputToResultIndex; +}; + +struct PackedScalarRunSlot { + SmallVector keys; +}; + +enum class PackedScalarRunKind { + Materialized, + DeferredReceive, + DeferredLocalCompute +}; + +struct PackedScalarRunValue { + ClassId targetClass = 0; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + PackedScalarRunKind kind = PackedScalarRunKind::Materialized; + + Value packed; + + RankedTensorType fragmentType; + SmallVector slots; + MessageVector messages; +}; + +struct IndexedBatchRunValue { + ClassId targetClass = 0; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + RankedTensorType fragmentType; + SmallVector slots; + MessageVector messages; +}; + +struct LogicalSlotRange { + SlotId start = 0; + SlotId count = 0; +}; + +struct MaterializationRunSlot { + SmallVector peers; +}; + +using MaterializationRun = SmallVector; + +struct OutputDestinationGroup { + SmallVector resultIndices; + SmallVector destinationClasses; +}; + +struct BatchRunSendPlan { + size_t resultIndex = 0; + ClassId destinationClass = 0; + MessageVector messages; +}; + +struct ProjectedBatchInputKey { + Operation* consumerOp = nullptr; + unsigned inputIndex = 0; + + bool operator==(const ProjectedBatchInputKey& other) const { + return consumerOp == other.consumerOp && inputIndex == other.inputIndex; + } +}; + +struct ProjectedBatchInputKeyInfo { + static ProjectedBatchInputKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProjectedBatchInputKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const 
ProjectedBatchInputKey& key) { + return llvm::hash_combine(key.consumerOp, key.inputIndex); + } + + static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } +}; + +struct ProjectedFragmentLayout { + RankedTensorType fragmentType; + SmallVector fragmentShape; + unsigned fragmentsPerLogicalSlot = 1; + unsigned payloadFragmentCount = 1; + SmallVector loopLowerBounds; + SmallVector loopSteps; + SmallVector loopTripCounts; +}; + +struct ProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + Operation* extractOp = nullptr; + + ProjectedFragmentLayout layout; + RankedTensorType payloadType; + SmallVector, 16> fragmentOffsets; + SmallVector, 4> fragmentOffsetsByDim; +}; + +struct ProjectedExtractReplacement { + Value payload; + ProjectedFragmentLayout layout; +}; + +struct CloneIndexingContext { + std::optional runSlotIndex; + std::optional projectionSlotIndex; +}; + +struct StaticProjectedLoopInfo { + BlockArgument iv; + int64_t lowerBound = 0; + int64_t step = 1; + int64_t tripCount = 1; +}; + +struct AffineProjectedInputSliceMatch { + tensor::ExtractSliceOp extract; + RankedTensorType sourceType; + RankedTensorType fragmentType; + SmallVector fragmentShape; + SmallVector offsets; + SmallVector loops; +}; + +struct MaterializerState; + +struct PendingProjectedHostReceiveGroup { + Value originalOutput; + ClassId ownerClassId = 0; + RankedTensorType fragmentType; + SmallVector keys; + MessageVector messages; + Location loc; +}; + +struct PendingScalarReceiveRecord { + PendingScalarReceiveRecord(ArrayRef keys, + ClassId targetClassId, + Type receiveType, + const MessageVector& messages, + Location loc) + : targetClassId(targetClassId), + receiveType(receiveType), + messages(messages), + loc(loc) { + this->keys.append(keys.begin(), keys.end()); + } + + SmallVector keys; + ClassId targetClassId = 0; + Type receiveType; + MessageVector messages; + Location loc; + bool materialized = false; + Value value; +}; + 
+FailureOr materializeProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + const ProjectedExtractReplacement& replacement, + std::optional projectionSlotIndex, + IRMapping* mapper = nullptr); +FailureOr rematerializeTensorValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + IRMapping* mapper = nullptr); +FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + std::optional producer = std::nullopt, + IRMapping* mapper = nullptr); +FailureOr localizeMaterializedClassOperand(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper = nullptr); +LogicalResult localizeCapturesInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + Operation& clonedOp, + IRMapping* mapper = nullptr); +bool requiresConstantProjectionSlotIndex(MaterializerState& state, + MaterializedClass& targetClass, + Operation* sourceOp); +bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane); + +class AvailableValueStore { +public: + struct ExactBatchFragmentRecord { + ProducerKey key; + Value value; + }; + + void record(ProducerKey key, ClassId classId, Value value) { + exactValues[key][classId] = value; + + auto batch = dyn_cast_or_null(key.instance.op); + if (!batch || key.instance.laneCount == 0) + return; + + WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; + SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; + for (ExactBatchFragmentRecord& record : bucket) { + if (!(record.key == key)) + continue; + record.value 
= value; + return; + } + bucket.push_back({key, value}); + } + + void recordPackedRun(PackedScalarRunValue run) { + size_t runIndex = packedScalarRuns.size(); + packedScalarRuns.push_back(std::move(run)); + const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; + WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; + packedRunsByProducerResultClass[lookupKey].push_back(runIndex); + } + void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } + + std::optional lookupExact(ProducerKey key, ClassId classId) const; + + std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); + IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); + + ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = packedRunsByProducerResultClass.find(key); + if (it == packedRunsByProducerResultClass.end()) + return {}; + return it->second; + } + + ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = exactBatchFragmentsByProducerResultClass.find(key); + if (it == exactBatchFragmentsByProducerResultClass.end()) + return {}; + return it->second; + } + + PackedScalarRunValue& getPackedRun(size_t index) { return packedScalarRuns[index]; } + +private: + std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); + + DenseMap, ProducerKeyInfo> exactValues; + SmallVector packedScalarRuns; + SmallVector indexedBatchRuns; + DenseMap, WholeBatchAssemblyLookupKeyInfo> + exactBatchFragmentsByProducerResultClass; + DenseMap, WholeBatchAssemblyLookupKeyInfo> + packedRunsByProducerResultClass; +}; + +struct MaterializerState { + func::FuncOp func; + const MergeScheduleResult& schedule; + IRRewriter rewriter; + OperationFolder constantFolder; + int64_t& nextChannelId; + SmallVector classes; + DenseMap cpuToClass; + DenseMap> 
logicalInstancesByCpu; + DenseMap scheduledInstanceToLogicalSlots; + DenseMap logicalInstanceToScheduledChunk; + DenseSet materializedLogicalSlots; + + DenseMap, ProducerKeyInfo> producerDestClasses; + DenseMap, SameClassConsumerLookupKeyInfo> + sameClassConsumerIndex; + DenseMap projectedInputMatches; + DenseSet nonProjectedInputs; + DenseMap liveExternalUseCache; + DenseMap> batchOutputFragmentTypesCache; + DenseMap, llvm::DenseMapInfo> computeInstanceOutputsCache; + DenseMap, ProducerKeyInfo> projectedTransfers; + DenseMap> projectedExtractReplacements; + AvailableValueStore availableValues; + DenseMap hostReplacements; + DenseMap hostOutputOwners; + SmallVector pendingProjectedHostReceives; + SmallVector pendingScalarReceives; + DenseMap, ProducerKeyInfo> pendingScalarReceiveLookup; + DenseMap firstLateCommunicationOps; + int64_t nextCommunicationTraceId = 0; + DenseSet oldComputeOps; + + MaterializerState(func::FuncOp func, + const MergeScheduleResult& schedule, + int64_t& nextChannelId) + : func(func), + schedule(schedule), + rewriter(func.getContext()), + constantFolder(func.getContext()), + nextChannelId(nextChannelId) {} +}; + +bool isConstantLike(Value value) { + Operation* definingOp = value.getDefiningOp(); + return definingOp && definingOp->hasTrait(); +} + +bool isInsideOldCompute(Operation* op, const DenseSet& oldComputeOps) { + for (Operation* current = op; current; current = current->getParentOp()) + if (oldComputeOps.contains(current)) + return true; + return false; +} + +bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps); +ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance); + +bool hasLiveExternalUseCached(MaterializerState& state, Value value) { + auto cached = state.liveExternalUseCache.find(value); + if (cached != state.liveExternalUseCache.end()) + return cached->second; + bool live = hasLiveExternalUse(value, state.oldComputeOps); + state.liveExternalUseCache[value] = live; + 
return live; +} + +std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extract) { + if (extract.getMixedOffsets().empty()) + return std::nullopt; + + OpFoldResult offset = extract.getMixedOffsets().front(); + if (auto attr = dyn_cast(offset)) { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() < 0) + return std::nullopt; + return static_cast(intAttr.getInt()); + } + + auto value = cast(offset); + if (auto constantIndex = value.getDefiningOp()) { + if (constantIndex.value() < 0) + return std::nullopt; + return static_cast(constantIndex.value()); + } + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) { + if (constantValue.isNegative()) + return std::nullopt; + return static_cast(constantValue.getZExtValue()); + } + + return std::nullopt; +} + +ProducerKey +getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) { + return { + {batch.getOperation(), laneStart, laneCount}, + resultIndex + }; +} + +ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { + return getBatchLaneProducerKey(batch, 0, static_cast(batch.getLaneCount()), resultIndex); +} + +bool isWholeBatchProducerKey(ProducerKey key) { + auto batch = dyn_cast_or_null(key.instance.op); + return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 + && key.instance.laneCount == static_cast(batch.getLaneCount()); +} + +std::optional getContiguousProducerRangeForKeys(ArrayRef keys) { + if (keys.empty()) + return std::nullopt; + + ProducerKey first = keys.front(); + auto batch = dyn_cast_or_null(first.instance.op); + if (!batch) + return std::nullopt; + + SmallVector sorted(keys.begin(), keys.end()); + llvm::sort(sorted, [](ProducerKey lhs, ProducerKey rhs) { + return std::tie(lhs.instance.laneStart, lhs.instance.laneCount, lhs.resultIndex) + < std::tie(rhs.instance.laneStart, rhs.instance.laneCount, rhs.resultIndex); + }); + + uint32_t laneStart = 
sorted.front().instance.laneStart; + uint32_t nextLane = laneStart; + for (ProducerKey key : sorted) { + if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) + return std::nullopt; + if (key.instance.laneStart != nextLane) + return std::nullopt; + nextLane += key.instance.laneCount; + } + + uint32_t laneCount = nextLane - laneStart; + if (laneStart + laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + + return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); +} + +std::optional getPhysicallyContiguousProducerRangeForKeys(ArrayRef keys) { + if (keys.empty()) + return std::nullopt; + + ProducerKey first = keys.front(); + auto batch = dyn_cast_or_null(first.instance.op); + if (!batch || first.instance.laneCount == 0) + return std::nullopt; + + uint32_t laneStart = first.instance.laneStart; + uint32_t nextLane = laneStart; + for (ProducerKey key : keys) { + if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) + return std::nullopt; + if (key.instance.laneStart != nextLane) + return std::nullopt; + nextLane += key.instance.laneCount; + } + + uint32_t laneCount = nextLane - laneStart; + if (laneStart + laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + + return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); +} + +WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) { + return {sourceOp, resultIndex, classId}; +} + +WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) { + return makeWholeBatchAssemblyLookupKey(key.instance.op, key.resultIndex, classId); +} + +FailureOr getPackedBatchTensorType(Type laneType, size_t laneCount) { + auto tensorType = dyn_cast(laneType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return 
failure(); + + SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(laneCount); + return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +LogicalResult verifyPackableFragmentType(Operation* anchor, Type fragmentType, size_t count, StringRef message) { + if (failed(getPackedBatchTensorType(fragmentType, count))) + return anchor->emitError(message); + return success(); +} + +ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, ComputeInstance logicalInstance) { + auto it = state.logicalInstanceToScheduledChunk.find(logicalInstance); + if (it != state.logicalInstanceToScheduledChunk.end()) + return it->second; + return logicalInstance; +} + +SmallVector +collectProducerKeysForDestinations(Value value, std::optional logicalConsumer = std::nullopt) { + // Destination collection works in the materializer's logical one-lane key domain. + // Whole-batch resultful producers are expanded into per-lane producer keys here. 
+ SmallVector keys; + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return keys; + + while (auto extract = dyn_cast(definingOp)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + auto result = dyn_cast(source); + if (!result) + return {}; + + if (std::optional lane = getConstantFirstSliceOffset(extract)) { + if (*lane >= static_cast(batch.getLaneCount())) + return {}; + keys.push_back(getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber())); + return keys; + } + + return {}; + } + + value = source; + definingOp = value.getDefiningOp(); + if (!definingOp) + return {}; + } + + if (auto compute = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return {}; + keys.push_back({ + {compute.getOperation(), 0, 1}, + result.getResultNumber() + }); + return keys; + } + + if (auto batch = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return {}; + + if (batch.getNumResults() != 0) { + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) + keys.push_back(getBatchLaneProducerKey(batch, lane, 1, result.getResultNumber())); + return keys; + } + + ComputeInstance chunk = getBatchChunkForLane(batch, result.getResultNumber()); + keys.push_back({chunk, static_cast(result.getResultNumber() - chunk.laneStart)}); + return keys; + } + + return keys; +} + +std::optional getInputRequestProducerKey(Value value, + std::optional logicalConsumer = std::nullopt) { + // Input resolution may request a whole-batch key for scalar consumers that read + // a complete resultful compute_batch value. 
+ Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return std::nullopt; + + while (auto extract = dyn_cast(definingOp)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + auto result = dyn_cast(source); + if (!result) + return std::nullopt; + + if (std::optional lane = getConstantFirstSliceOffset(extract)) + return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); + + return std::nullopt; + } + + value = source; + definingOp = value.getDefiningOp(); + if (!definingOp) + return std::nullopt; + } + + if (auto compute = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + return ProducerKey { + {compute.getOperation(), 0, 1}, + result.getResultNumber() + }; + } + + if (auto batch = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + + if (batch.getNumResults() != 0) + return getWholeBatchProducerKey(batch, result.getResultNumber()); + + return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; + } + + return std::nullopt; +} + +class CpuUnionFind { +public: + void insert(CpuId cpu) { parent.try_emplace(cpu, cpu); } + + CpuId find(CpuId cpu) { + insert(cpu); + CpuId p = parent.lookup(cpu); + if (p == cpu) + return cpu; + CpuId root = find(p); + parent[cpu] = root; + return root; + } + + void unite(CpuId lhs, CpuId rhs) { + CpuId lhsRoot = find(lhs); + CpuId rhsRoot = find(rhs); + if (lhsRoot == rhsRoot) + return; + if (rhsRoot < lhsRoot) + std::swap(lhsRoot, rhsRoot); + parent[rhsRoot] = lhsRoot; + } + +private: + DenseMap parent; +}; + +LogicalResult buildMaterializationWorkStreams(MaterializerState& state) { + DenseMap> scheduledInstancesByCpu; + for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { + state.oldComputeOps.insert(instance.op); + scheduledInstancesByCpu[cpu].push_back(instance); + 
state.logicalInstancesByCpu.try_emplace(cpu); + } + + for (auto& [cpu, scheduledInstances] : scheduledInstancesByCpu) { + llvm::sort(scheduledInstances, [&](const ComputeInstance& lhs, const ComputeInstance& rhs) { + auto lhsIt = state.schedule.computeToCpuSlotMap.find(lhs); + auto rhsIt = state.schedule.computeToCpuSlotMap.find(rhs); + assert(lhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); + assert(rhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); + return lhsIt->second < rhsIt->second; + }); + + SmallVector& logicalInstances = state.logicalInstancesByCpu[cpu]; + SlotId logicalSlot = 0; + for (const ComputeInstance& instance : scheduledInstances) { + LogicalSlotRange range {logicalSlot, 1}; + if (isa(instance.op)) + range.count = instance.laneCount; + + state.scheduledInstanceToLogicalSlots[instance] = range; + + if (isa(instance.op)) { + for (uint32_t localLane = 0; localLane < instance.laneCount; ++localLane, ++logicalSlot) { + uint32_t logicalLane = instance.laneStart + localLane; + ComputeInstance logicalInstance {instance.op, logicalLane, 1}; + logicalInstances.push_back(logicalInstance); + state.logicalInstanceToScheduledChunk[logicalInstance] = instance; + } + continue; + } + + logicalInstances.push_back(instance); + ++logicalSlot; + } + } + + return success(); +} + +LogicalResult buildMaterializationClassesFromScheduleEquivalence(MaterializerState& state) { + DenseSet usedCpus; + for (const auto& entry : state.schedule.cpuToLastComputeMap) + usedCpus.insert(entry.first); + for (const auto& entry : state.schedule.computeToCpuMap) + usedCpus.insert(entry.second); + + CpuUnionFind unionFind; + for (CpuId cpu : usedCpus) + unionFind.insert(cpu); + + for (const auto& [cpu, equivalentCpus] : state.schedule.equivalentClass) { + if (!usedCpus.contains(cpu)) + continue; + for (CpuId equivalentCpu : equivalentCpus) + if (usedCpus.contains(equivalentCpu)) + unionFind.unite(cpu, equivalentCpu); + } + + 
DenseMap> groupsByRoot; + for (CpuId cpu : usedCpus) + groupsByRoot[unionFind.find(cpu)].push_back(cpu); + + SmallVector roots; + roots.reserve(groupsByRoot.size()); + for (const auto& entry : groupsByRoot) + roots.push_back(entry.first); + llvm::sort(roots); + + state.classes.reserve(roots.size()); + for (CpuId root : roots) { + MaterializedClass materializedClass; + materializedClass.id = state.classes.size(); + materializedClass.cpus = groupsByRoot.lookup(root); + llvm::sort(materializedClass.cpus); + materializedClass.isBatch = materializedClass.cpus.size() > 1; + for (auto [lane, cpu] : llvm::enumerate(materializedClass.cpus)) { + materializedClass.cpuToLane[cpu] = static_cast(lane); + state.cpuToClass[cpu] = materializedClass.id; + } + state.classes.push_back(std::move(materializedClass)); + } + + return success(); +} + +LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState& state) { + for (const MaterializedClass& materializedClass : state.classes) { + if (materializedClass.cpus.empty()) + continue; + + auto referenceIt = state.logicalInstancesByCpu.find(materializedClass.cpus.front()); + if (referenceIt == state.logicalInstancesByCpu.end()) + return state.func.emitError("missing logical stream for materialized class reference CPU"); + + ArrayRef referenceStream(referenceIt->second); + for (CpuId cpu : materializedClass.cpus) { + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end()) + return state.func.emitError("missing logical stream for materialized class CPU"); + + ArrayRef stream(streamIt->second); + if (stream.size() != referenceStream.size()) + return state.func.emitError("materialized class CPUs have mismatched logical stream lengths"); + + for (auto [slot, zipped] : llvm::enumerate(llvm::zip(referenceStream, stream))) { + const ComputeInstance& referenceInstance = std::get<0>(zipped); + const ComputeInstance& currentInstance = std::get<1>(zipped); + if 
(referenceInstance.op != currentInstance.op) + return state.func.emitError("materialized class logical slot source op mismatch"); + if (isa(referenceInstance.op) != isa(currentInstance.op)) + return state.func.emitError("materialized class logical slot batch/scalar mismatch"); + (void) slot; + } + } + } + + return success(); +} + +LogicalResult forEachLogicalConsumerInMaterializationOrder( + MaterializerState& state, + llvm::function_ref + callback) { + for (const ComputeInstance& scheduledInstance : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return scheduledInstance.op->emitError("missing CPU assignment for scheduled logical-slot iteration"); + + auto rangeIt = state.scheduledInstanceToLogicalSlots.find(scheduledInstance); + if (rangeIt == state.scheduledInstanceToLogicalSlots.end()) + return scheduledInstance.op->emitError("missing logical slot range for scheduled logical-slot iteration"); + + CpuId cpu = cpuIt->second; + ClassId classId = state.cpuToClass.lookup(cpu); + LogicalSlotRange range = rangeIt->second; + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end()) + return scheduledInstance.op->emitError("missing logical stream for CPU"); + for (SlotId logicalSlot = range.start; logicalSlot < range.start + range.count; ++logicalSlot) { + if (logicalSlot >= streamIt->second.size()) + return scheduledInstance.op->emitError("missing logical slot materialization instance"); + if (failed(callback(cpu, classId, scheduledInstance, streamIt->second[logicalSlot], logicalSlot))) + return failure(); + } + } + + return success(); +} + +bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps); + +LogicalResult collectHostOutputs(MaterializerState& state) { + DenseSet seenOutputs; + SmallVector orderedOutputs; + DenseMap preferredOwners; + + for (const ComputeInstance& instance : 
state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(instance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); + + ClassId classId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& materializedClass = state.classes[classId]; + for (Value output : getComputeInstanceOutputValuesCached(state, instance)) { + if (!hasLiveExternalUseCached(state, output)) + continue; + + if (seenOutputs.insert(output).second) { + orderedOutputs.push_back(output); + preferredOwners[output] = classId; + continue; + } + + auto batch = dyn_cast_or_null(output.getDefiningOp()); + if (!batch || batch.getNumResults() == 0) + continue; + + ClassId currentOwner = preferredOwners.lookup(output); + bool terminalHost = isTerminalHostBatchOutput(output, state.oldComputeOps); + if (terminalHost) { + // Terminal resultful batch outputs are still published through scalar + // host-output slots unless the materialized batch class owns the output + // directly. Selecting an arbitrary batch class as the host owner would + // require a projection-aware batch publication path, which the + // materializer does not currently implement. 
+ if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + preferredOwners[output] = classId; + continue; + } + + if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + preferredOwners[output] = classId; + } + } + + for (MaterializedClass& materializedClass : state.classes) { + materializedClass.hostOutputs.clear(); + materializedClass.hostOutputToResultIndex.clear(); + } + state.hostOutputOwners.clear(); + + for (Value output : orderedOutputs) { + ClassId ownerClassId = preferredOwners.lookup(output); + MaterializedClass& ownerClass = state.classes[ownerClassId]; + ownerClass.hostOutputToResultIndex[output] = ownerClass.hostOutputs.size(); + ownerClass.hostOutputs.push_back(output); + state.hostOutputOwners[output] = ownerClassId; + } + + return success(); +} + +LogicalResult createEmptyMaterializedOps(MaterializerState& state) { + Location loc = state.func.getLoc(); + Block& funcBlock = state.func.getBody().front(); + + Operation* firstOldCompute = nullptr; + for (Operation& op : funcBlock) { + if (state.oldComputeOps.contains(&op)) { + firstOldCompute = &op; + break; + } + } + + if (firstOldCompute) + state.rewriter.setInsertionPoint(firstOldCompute); + else + state.rewriter.setInsertionPointToStart(&funcBlock); + + for (MaterializedClass& materializedClass : state.classes) { + SmallVector resultTypes; + resultTypes.reserve(materializedClass.hostOutputs.size()); + for (Value output : materializedClass.hostOutputs) + resultTypes.push_back(output.getType()); + + if (!materializedClass.isBatch) { + auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); + compute.getProperties().setOperandSegmentSizes({0, 0}); + auto coreIdAttr = + pim::getCheckedI32Attr(state.rewriter, state.func, materializedClass.cpus.front(), "materialized core id"); + if (failed(coreIdAttr)) + return failure(); + compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); + Block* body = 
state.rewriter.createBlock(&compute.getBody()); + state.rewriter.setInsertionPointToEnd(body); + SmallVector placeholderOutputs; + placeholderOutputs.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto tensorType = dyn_cast(resultType); + if (!tensorType || !tensorType.hasStaticShape()) { + compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); + return failure(); + } + placeholderOutputs.push_back( + tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); + } + SpatYieldOp::create(state.rewriter, loc, ValueRange(placeholderOutputs)); + materializedClass.op = compute.getOperation(); + materializedClass.body = body; + state.rewriter.setInsertionPointAfter(compute.getOperation()); + continue; + } + + auto batchLaneCountAttr = pim::getCheckedI32Attr( + state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); + if (failed(batchLaneCountAttr)) + return failure(); + auto batch = SpatScheduledComputeBatch::create(state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); + batch.getProperties().setOperandSegmentSizes({0, 0}); + auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id"); + if (failed(coreIds)) + return failure(); + batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(*coreIds)); + + SmallVector blockArgTypes {state.rewriter.getIndexType()}; + SmallVector blockArgLocs {loc}; + llvm::append_range(blockArgTypes, resultTypes); + blockArgLocs.append(resultTypes.size(), loc); + Block* body = + state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + state.rewriter.setInsertionPointToEnd(body); + if (resultTypes.empty()) + SpatYieldOp::create(state.rewriter, loc, ValueRange {}); + else + SpatInParallelOp::create(state.rewriter, loc); + 
materializedClass.op = batch.getOperation(); + materializedClass.body = body; + state.rewriter.setInsertionPointAfter(batch.getOperation()); + } + + return success(); +} + +BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { + auto it = materializedClass.weightArgs.find(weight); + if (it != materializedClass.weightArgs.end()) + return it->second; + + unsigned weightIndex = materializedClass.weights.size(); + materializedClass.weights.push_back(weight); + + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc()); + assert(arg && "expected compute body while inserting a weight"); + materializedClass.weightArgs[weight] = std::get<1>(*arg); + return std::get<1>(*arg); + } + + auto batch = cast(materializedClass.op); + auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc()); + assert(arg && "expected compute_batch body while inserting a weight argument"); + materializedClass.weightArgs[weight] = std::get<1>(*arg); + return std::get<1>(*arg); +} + +BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) { + auto it = materializedClass.inputArgs.find(input); + if (it != materializedClass.inputArgs.end()) + return it->second; + + materializedClass.inputs.push_back(input); + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); + assert(arg && "expected compute body while inserting an input"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); + } + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); + assert(arg && "expected compute_batch body while inserting an input argument"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); + } + 
llvm_unreachable("Cannot reach here"); +} + +Region* getParentRegion(Value value) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getOwner()->getParent(); + if (Operation* definingOp = value.getDefiningOp()) + return definingOp->getParentRegion(); + return nullptr; +} + +bool isDefinedInsideRegion(Value value, Region& region) { + Region* parentRegion = getParentRegion(value); + return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); +} + +Operation* getEnclosingSpatialComputeLikeOp(Value value) { + Block* block = nullptr; + if (auto blockArg = dyn_cast(value)) + block = blockArg.getOwner(); + else if (Operation* definingOp = value.getDefiningOp()) + block = definingOp->getBlock(); + + if (!block) + return nullptr; + + for (Operation* current = block->getParentOp(); current; current = current->getParentOp()) + if (isa(current)) + return current; + return nullptr; +} + +bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType())) + return true; + if (isConstantLike(value)) + return true; + + Region& targetRegion = *targetClass.body->getParent(); + return isDefinedInsideRegion(value, targetRegion); +} + +bool isTensorValueDefinedInDifferentMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType()) || isTensorValueLocalToMaterializedClass(value, targetClass)) + return false; + + Operation* owner = getEnclosingSpatialComputeLikeOp(value); + return owner && owner != targetClass.op; +} + +std::optional getRegionIndexInParentOp(Region* region) { + Operation* parent = region ? region->getParentOp() : nullptr; + if (!parent) + return std::nullopt; + + for (auto [index, candidate] : llvm::enumerate(parent->getRegions())) + if (&candidate == region) + return static_cast(index); + return std::nullopt; +} + +std::optional getBlockIndexInRegion(Block* block) { + Region* region = block ? 
block->getParent() : nullptr; + if (!region) + return std::nullopt; + + for (auto [index, candidate] : llvm::enumerate(region->getBlocks())) + if (&candidate == block) + return static_cast(index); + return std::nullopt; +} + +Block* getBlockByIndex(Region& region, unsigned blockIndex) { + unsigned index = 0; + for (Block& block : region) { + if (index == blockIndex) + return █ + ++index; + } + return nullptr; +} + +static bool isValueLegalInMaterializedClassBody(Value value, const MaterializedClass& targetClass) { + if (isConstantLike(value)) + return true; + + Region& targetRegion = *targetClass.body->getParent(); + return isDefinedInsideRegion(value, targetRegion); +} + +std::string stringifyOperationForMaterializerDebug(Operation* op) { + if (!op) + return std::string(""); + std::string storage; + llvm::raw_string_ostream stream(storage); + op->print(stream); + return storage; +} + +std::string stringifyValueForMaterializerDebug(Value value) { + std::string storage; + llvm::raw_string_ostream stream(storage); + value.print(stream); + return storage; +} + +std::string truncateMaterializerDebugString(std::string text, size_t limit = 1200) { + for (char& ch : text) + if (ch == '\n' || ch == '\r' || ch == '\t') + ch = ' '; + + if (text.size() <= limit) + return text; + text.resize(limit); + text += "..."; + return text; +} + +std::string formatMaterializerOperandListInline(Operation* op, const MaterializedClass& targetClass) { + if (!op) + return std::string(""); + + std::string storage; + llvm::raw_string_ostream stream(storage); + for (OpOperand& operand : op->getOpOperands()) { + if (operand.getOperandNumber() != 0) + stream << " | "; + Value value = operand.get(); + stream << "operand#" << operand.getOperandNumber() << " type=" << value.getType() + << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 
1 : 0) + << " value=" << stringifyValueForMaterializerDebug(value); + if (auto blockArg = dyn_cast(value)) { + stream << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + stream << " ownerOp='" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + stream << " definingOp='" << definingOp->getName() << "'"; + } + } + return truncateMaterializerDebugString(stream.str()); +} + +std::string formatMaterializerParentChainInline(Operation* op) { + if (!op) + return std::string(""); + + std::string storage; + llvm::raw_string_ostream stream(storage); + unsigned depth = 0; + for (Operation* current = op; current; current = current->getParentOp()) { + if (depth != 0) + stream << " <- "; + stream << "[" << depth++ << "]" << current->getName(); + } + return truncateMaterializerDebugString(stream.str()); +} + +void attachMaterializerOperationPrintNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { + if (!op) + return; + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stringifyOperationForMaterializerDebug(op); +} + +void attachMaterializerParentChainNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { + if (!op) + return; + + std::string storage; + llvm::raw_string_ostream stream(storage); + unsigned depth = 0; + for (Operation* current = op; current; current = current->getParentOp()) + stream << " [" << depth++ << "] " << current->getName() << "\n"; + + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); +} + +void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, + Operation* op, + const MaterializedClass& targetClass, + StringRef label) { + if (!op) + return; + + std::string storage; + llvm::raw_string_ostream stream(storage); + for (OpOperand& operand : op->getOpOperands()) { + Value value = operand.get(); + stream << " operand#" << operand.getOperandNumber() << " type=" << value.getType() + << " local=" 
<< (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) + << " value=" << stringifyValueForMaterializerDebug(value); + if (auto blockArg = dyn_cast(value)) { + stream << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + stream << " ownerOp='" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + stream << " definingOp='" << definingOp->getName() << "'"; + } + stream << "\n"; + } + + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); +} + +void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value value, StringRef label) { + if (auto blockArg = dyn_cast(value)) { + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic.attachNote(owner->getLoc()) + << label << " is block argument #" << blockArg.getArgNumber() << " of '" << owner->getName() + << "' with type " << blockArg.getType(); + else + diagnostic.attachNote(UnknownLoc::get(value.getContext())) + << label << " is a top-level block argument #" << blockArg.getArgNumber() + << " with type " << blockArg.getType(); + return; + } + + if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) + << label << " is defined by '" << definingOp->getName() << "' with result type " << value.getType(); + return; + } + + diagnostic.attachNote(UnknownLoc::get(value.getContext())) + << label << " has no defining operation and is not a block argument, type " << value.getType(); +} + +void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { + Block& body = *targetClass.body; + diagnostic.attachNote(targetClass.op->getLoc()) + << "RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName() + << "' body has " << body.getNumArguments() << " block arguments and " + << std::distance(body.begin(), body.end()) << " top-level operations"; +} + 
+FailureOr rematerializeIndexValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Location loc, + IRMapping* mapper = nullptr); + +FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& state, + MaterializedClass& targetClass, + OpFoldResult value, + Location loc, + IRMapping* mapper = nullptr) { + if (auto attr = dyn_cast(value)) + return OpFoldResult(attr); + + FailureOr rematerialized = rematerializeIndexValueInClass(state, targetClass, cast(value), loc, mapper); + if (failed(rematerialized)) + return failure(); + return OpFoldResult(*rematerialized); +} + +FailureOr rematerializeIndexValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Location loc, + IRMapping* mapper) { + Value originalValue = value; + bool mapperHadOriginalValue = false; + Value mappedOriginalValue; + + if (mapper && mapper->contains(value)) { + mapperHadOriginalValue = true; + Value mapped = mapper->lookup(value); + mappedOriginalValue = mapped; + if (isValueLegalInMaterializedClassBody(mapped, targetClass) || isConstantLike(mapped)) + return mapped; + value = mapped; + } + + if (isValueLegalInMaterializedClassBody(value, targetClass)) + return value; + + if (!value.getType().isIndex()) + return targetClass.op->emitError("cannot rematerialize non-index external value in materialized class body") + << " type=" << value.getType(); + + if (auto constantIndex = value.getDefiningOp()) + return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantIndex.value()); + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) { + if (!constantValue.isSignedIntN(64)) + return targetClass.op->emitError("cannot rematerialize out-of-range index constant") + << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); + return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantValue.getSExtValue()); + } + + if (auto affineApply = value.getDefiningOp()) { 
+ SmallVector remappedOperands; + remappedOperands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, operand, loc, mapper); + if (failed(remapped)) + return failure(); + remappedOperands.push_back(*remapped); + } + return createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func); + } + + if (auto addOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, addOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, addOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::AddIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto subOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, subOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, subOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::SubIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto mulOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::MulIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto divOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, divOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, divOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::DivUIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto extractOp = 
value.getDefiningOp()) { + SmallVector remappedIndices; + remappedIndices.reserve(extractOp.getIndices().size()); + for (Value index : extractOp.getIndices()) { + FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, index, loc, mapper); + if (failed(remapped)) + return failure(); + remappedIndices.push_back(*remapped); + } + + Value tensor = extractOp.getTensor(); + if (!isConstantLike(tensor) && !isValueLegalInMaterializedClassBody(tensor, targetClass)) + return targetClass.op->emitError("cannot rematerialize indexed table lookup from external non-constant tensor") + << " tensorType=" << tensor.getType(); + return tensor::ExtractOp::create(state.rewriter, loc, tensor, remappedIndices).getResult(); + } + + if (auto blockArg = dyn_cast(value)) { + InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body"); + diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() + << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; + if (Operation* owner = blockArg.getOwner()->getParentOp()) { + diagnostic << " ownerOp='" << owner->getName() << "'"; + diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) << "\""; + diagnostic << " ownerChain=\"" << formatMaterializerParentChainInline(owner) << "\""; + } + diagnostic << " targetIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; + if (mapper) { + diagnostic << " mapperPresent=1 mapperHadOriginal=" << (mapperHadOriginalValue ? 
1 : 0); + if (mapperHadOriginalValue) + diagnostic << " mappedType=" << mappedOriginalValue.getType(); + } else { + diagnostic << " mapperPresent=0"; + } + attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); + if (value != originalValue) + attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); + if (mapperHadOriginalValue && mappedOriginalValue != value) + attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value"); + if (Operation* owner = blockArg.getOwner()->getParentOp()) { + attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op"); + attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain"); + } + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); + } + + InFlightDiagnostic diagnostic = + targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body"); + diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='" + << targetClass.op->getName() << "'"; + attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); + if (value != originalValue) + attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); +} + +InFlightDiagnostic emitNonLocalMaterializedClassValueDiagnostic(Operation* anchor, + const MaterializedClass& targetClass, + StringRef context, + Value value, + std::optional producer = std::nullopt) { + InFlightDiagnostic diagnostic = anchor->emitError(context) << " into target class " << targetClass.id; + + if (producer) { + diagnostic << " from '" << producer->instance.op->getName() << "' 
resultIndex=" << producer->resultIndex + << " laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount; + } else if (auto result = dyn_cast(value)) { + diagnostic << " from '" << result.getOwner()->getName() << "' resultIndex=" << result.getResultNumber(); + } else if (auto blockArg = dyn_cast(value)) { + diagnostic << " from block argument #" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic << " of '" << owner->getName() << "'"; + } + + if (Operation* definingOp = value.getDefiningOp()) + diagnostic.attachNote(definingOp->getLoc()) << "offending tensor producer is '" << definingOp->getName() << "'"; + return diagnostic; +} + +FailureOr rematerializeTensorValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + IRMapping* mapper) { + auto extractSlice = value.getDefiningOp(); + if (extractSlice) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, targetClass, extractSlice.getSource(), anchor, context, std::nullopt, mapper); + if (failed(localizedSource)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(extractSlice.getMixedOffsets().size()); + sizes.reserve(extractSlice.getMixedSizes().size()); + strides.reserve(extractSlice.getMixedStrides().size()); + + for (OpFoldResult offset : extractSlice.getMixedOffsets()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, offset, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + offsets.push_back(*localized); + } + for (OpFoldResult size : extractSlice.getMixedSizes()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, size, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + sizes.push_back(*localized); + } + for (OpFoldResult stride : 
extractSlice.getMixedStrides()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, stride, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + strides.push_back(*localized); + } + + return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides) + .getResult(); + } + + if (auto collapseShape = value.getDefiningOp()) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, targetClass, collapseShape.getSrc(), anchor, context, std::nullopt, mapper); + if (failed(localizedSource)) + return failure(); + return tensor::CollapseShapeOp::create( + state.rewriter, anchor->getLoc(), *localizedSource, collapseShape.getReassociationIndices()) + .getResult(); + } + + return failure(); +} + +FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + std::optional producer, + IRMapping* mapper) { + if (mapper && mapper->contains(value)) + value = mapper->lookup(value); + + if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) + return value; + + if (value.getDefiningOp() || value.getDefiningOp()) { + FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); + if (failed(rematerialized)) + return failure(); + return *rematerialized; + } + + if (isTensorValueDefinedInDifferentMaterializedClass(value, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic(anchor, targetClass, context, value, producer); + return failure(); + } + + return appendInput(state, targetClass, value); +} + +std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, + Operation* anchor, + BlockArgument externalArg) { + Block* sourceBlock = externalArg.getOwner(); + Region* sourceRegion = sourceBlock ? 
sourceBlock->getParent() : nullptr; + Operation* sourceParent = sourceRegion ? sourceRegion->getParentOp() : nullptr; + if (!sourceParent || !anchor) + return std::nullopt; + + std::optional sourceRegionIndex = getRegionIndexInParentOp(sourceRegion); + std::optional sourceBlockIndex = getBlockIndexInRegion(sourceBlock); + if (!sourceRegionIndex || !sourceBlockIndex) + return std::nullopt; + + for (Operation* current = anchor->getParentOp(); current && current != targetClass.op; + current = current->getParentOp()) { + if (current->getName() != sourceParent->getName()) + continue; + if (current->getNumRegions() <= *sourceRegionIndex) + continue; + + Region& localRegion = current->getRegion(*sourceRegionIndex); + Block* localBlock = getBlockByIndex(localRegion, *sourceBlockIndex); + if (!localBlock || localBlock->getNumArguments() <= externalArg.getArgNumber()) + continue; + + BlockArgument localArg = localBlock->getArgument(externalArg.getArgNumber()); + if (localArg.getType() != externalArg.getType()) + continue; + if (!isValueLegalInMaterializedClassBody(localArg, targetClass)) + continue; + return localArg; + } + + return std::nullopt; +} + +FailureOr localizeMaterializedClassOperand(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper) { + if (mapper && mapper->contains(value)) + value = mapper->lookup(value); + + if (auto blockArg = dyn_cast(value)) + if (std::optional localArg = mapExternalRegionBlockArgumentToLocalClone(targetClass, anchor, blockArg)) + return *localArg; + + if (isa(value.getType())) + return materializeTensorValueForMaterializedClassUse(state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); + + if (isValueLegalInMaterializedClassBody(value, targetClass)) + return value; + + if (value.getType().isIndex()) + return rematerializeIndexValueInClass(state, targetClass, value, anchor->getLoc(), mapper); + + 
InFlightDiagnostic diagnostic = anchor->emitError(genericContext); + diagnostic << " type=" << value.getType(); + if (auto blockArg = dyn_cast(value)) { + diagnostic << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic.attachNote(owner->getLoc()) << "block argument belongs to '" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) << "unsupported external operand producer is '" << definingOp->getName() + << "'"; + } + return failure(); +} + +// ----------------------------------------------------------------------------- +// Tensor packing helpers. +// ----------------------------------------------------------------------------- + /* Offset, size and stride lists (one entry per tensor dimension) passed to the slice ops created below. */ +struct Dim0SliceParams { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + /* Builds slice parameters that take firstSize entries starting at firstOffset along dim 0 and, for every other dimension of referenceType, offset 0, the extent reported by getDimSize and stride 1. */ +Dim0SliceParams +buildDim0SliceParams(OpBuilder& builder, RankedTensorType referenceType, OpFoldResult firstOffset, int64_t firstSize) { + Dim0SliceParams params; + params.offsets.reserve(referenceType.getRank()); + params.sizes.reserve(referenceType.getRank()); + params.strides.reserve(referenceType.getRank()); + + params.offsets.push_back(firstOffset); + params.sizes.push_back(builder.getIndexAttr(firstSize)); + params.strides.push_back(builder.getIndexAttr(1)); + + for (int64_t dim = 1; dim < referenceType.getRank(); ++dim) { + params.offsets.push_back(builder.getIndexAttr(0)); + params.sizes.push_back(builder.getIndexAttr(referenceType.getDimSize(dim))); + params.strides.push_back(builder.getIndexAttr(1)); + } + + return params; +} + /* Creates a tensor::ExtractSliceOp on source with the dim-0 parameters from buildDim0SliceParams and returns its result. */ +Value createDim0ExtractSlice( + MaterializerState& state, Location loc, Value source, OpFoldResult firstOffset, int64_t firstSize) { + auto sourceType = cast(source.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, sourceType, firstOffset, firstSize); + return tensor::ExtractSliceOp::create(state.rewriter, loc, source, params.offsets, 
params.sizes, params.strides) + .getResult(); +} + /* Class-local variant of createDim0ExtractSlice: source and firstOffset are first localized into targetClass; failure of either localization step is returned as failure(). */ +FailureOr createDim0ExtractSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value source, + OpFoldResult firstOffset, + int64_t firstSize) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + source, + targetClass.op, + "createDim0ExtractSliceInClass tried to reuse a tensor from another materialized class"); + if (failed(localizedSource)) + return failure(); + FailureOr localizedOffset = + rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + if (failed(localizedOffset)) + return failure(); + return createDim0ExtractSlice(state, loc, *localizedSource, *localizedOffset, firstSize); +} + /* Creates a tensor::ExtractSliceOp with caller-supplied per-dimension offsets, sizes taken from resultShape and unit strides. Asserts that sliceOffsets and resultShape both have one entry per source dimension. */ +Value createStaticExtractSlice(MaterializerState& state, + Location loc, + Value source, + ArrayRef sliceOffsets, + ArrayRef resultShape) { + auto sourceType = cast(source.getType()); + assert(sliceOffsets.size() == static_cast(sourceType.getRank()) && "offset rank mismatch"); + assert(resultShape.size() == static_cast(sourceType.getRank()) && "result rank mismatch"); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(sourceType.getRank()); + sizes.reserve(sourceType.getRank()); + strides.reserve(sourceType.getRank()); + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + offsets.push_back(sliceOffsets[dim]); + sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim])); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); +} + /* Class-local variant of createStaticExtractSlice: the source tensor and every offset are localized into targetClass before the slice is created; the first localization failure is returned. */ +FailureOr createStaticExtractSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value source, + ArrayRef sliceOffsets, + ArrayRef resultShape) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + source, + targetClass.op, + 
"createStaticExtractSliceInClass tried to reuse a tensor from another materialized class"); + if (failed(localizedSource)) + return failure(); + + SmallVector localizedOffsets; + localizedOffsets.reserve(sliceOffsets.size()); + for (OpFoldResult offset : sliceOffsets) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); + if (failed(localized)) + return failure(); + localizedOffsets.push_back(*localized); + } + return createStaticExtractSlice(state, loc, *localizedSource, localizedOffsets, resultShape); +} + /* Forward declaration (with the default arguments) of the table/affine index lookup helper used by buildProjectedFragmentOffsetsInClass. */ +Value createIndexedIndexValue(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc, + std::optional preferredPeriod = std::nullopt, + bool allowExhaustiveTiledSearch = true); + /* Rematerializes flatFragmentIndex inside targetClass, then builds one index Value per entry of descriptor.fragmentOffsetsByDim via createIndexedIndexValue, passing payloadFragmentCount as the preferred period and disabling the exhaustive tiled search. */ +FailureOr> buildProjectedFragmentOffsetsInClass(MaterializerState& state, + MaterializedClass& targetClass, + const ProjectedTransferDescriptor& descriptor, + Value flatFragmentIndex, + Location loc) { + FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, flatFragmentIndex, loc); + if (failed(localizedIndex)) + return failure(); + SmallVector fragmentOffsets; + fragmentOffsets.reserve(descriptor.layout.fragmentShape.size()); + for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) + fragmentOffsets.push_back(createIndexedIndexValue(state, + targetClass.op, + dimOffsets, + *localizedIndex, + loc, + static_cast(descriptor.layout.payloadFragmentCount), + /*allowExhaustiveTiledSearch=*/false)); + return fragmentOffsets; +} + /* Creates a tensor::InsertSliceOp placing fragment into destination at firstOffset along dim 0; the dim-0 size inserted is the fragment's own dim-0 extent. Returns the op result. */ +Value createDim0InsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { + auto fragmentType = cast(fragment.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); + return tensor::InsertSliceOp::create( + state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides) + .getResult(); +} 
+ /* Class-local variant of createDim0InsertSlice: fragment, destination and firstOffset are each localized into targetClass first; the first localization failure is returned as failure(). */ +FailureOr createDim0InsertSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value fragment, + Value destination, + OpFoldResult firstOffset) { + FailureOr localizedFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment, + targetClass.op, + "createDim0InsertSliceInClass tried to reuse a fragment tensor from another materialized class"); + if (failed(localizedFragment)) + return failure(); + FailureOr localizedDestination = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + destination, + targetClass.op, + "createDim0InsertSliceInClass tried to reuse a destination tensor from another materialized class"); + if (failed(localizedDestination)) + return failure(); + FailureOr localizedOffset = + rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + if (failed(localizedOffset)) + return failure(); + return createDim0InsertSlice(state, loc, *localizedFragment, *localizedDestination, *localizedOffset); +} + /* Creates a tensor::ParallelInsertSliceOp with the same dim-0 slice parameters as createDim0InsertSlice; nothing is returned. */ +void createDim0ParallelInsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { + auto fragmentType = cast(fragment.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); + tensor::ParallelInsertSliceOp::create( + state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides); +} + /* Returns index * dim0Size as an arith::MulIOp result; index is returned unchanged when dim0Size is 1. The multiplier constant is obtained through getOrCreateIndexConstant with the given anchor. */ +Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc) { + if (dim0Size == 1) + return index; + + Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size); + return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); +} + /* Same scaling as scaleIndexByDim0Size, but index is first rematerialized inside targetClass (also when dim0Size is 1). */ +FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value index, + int64_t dim0Size, + Location loc) { + FailureOr 
localizedIndex = rematerializeIndexValueInClass(state, targetClass, index, loc); + if (failed(localizedIndex)) + return failure(); + if (dim0Size == 1) + return *localizedIndex; + + Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size); + return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult(); +} + /* Two keys match when they name the same op and the same result index; the lane range is not compared. */ +bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) { + return lhs.instance.op == rhs.instance.op && lhs.resultIndex == rhs.resultIndex; +} + /* True when both keys name the same producer result, the producer op passes the isa check, neither lane count is zero, and inner's lane range [laneStart, laneStart + laneCount) lies inside outer's. */ +bool containsProducerKey(ProducerKey outer, ProducerKey inner) { + if (!sameProducerResult(outer, inner)) + return false; + if (!isa(outer.instance.op)) + return false; + if (outer.instance.laneCount == 0 || inner.instance.laneCount == 0) + return false; + + uint32_t outerStart = outer.instance.laneStart; + uint32_t outerEnd = outerStart + outer.instance.laneCount; + uint32_t innerStart = inner.instance.laneStart; + uint32_t innerEnd = innerStart + inner.instance.laneCount; + + return outerStart <= innerStart && innerEnd <= outerEnd; +} + /* Extracts from packed the dim-0 rows that belong to requestedKey, dividing dim 0 of packed evenly among packedKey's lanes. Returns std::nullopt when requestedKey is not contained in packedKey, packed is not a statically shaped ranked tensor of rank >= 1, the rows do not divide evenly by the lane count, or the slice cannot be created. */ +std::optional extractPackedProducerSlice(MaterializerState& state, + MaterializedClass& materializedClass, + ProducerKey packedKey, + Value packed, + ProducerKey requestedKey) { + if (!containsProducerKey(packedKey, requestedKey)) + return std::nullopt; + + auto packedType = dyn_cast(packed.getType()); + if (!packedType || !packedType.hasStaticShape() || packedType.getRank() == 0) + return std::nullopt; + + if (packedKey.instance.laneCount == 0) + return std::nullopt; + + int64_t packedRows = packedType.getDimSize(0); + if (packedRows % static_cast(packedKey.instance.laneCount) != 0) + return std::nullopt; + + int64_t rowsPerLane = packedRows / static_cast(packedKey.instance.laneCount); + int64_t rowOffset = + static_cast(requestedKey.instance.laneStart - packedKey.instance.laneStart) * rowsPerLane; + int64_t rowCount = static_cast(requestedKey.instance.laneCount) * rowsPerLane; + + 
state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); /* the slice is emitted immediately before the class body's terminator */ + +Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); + FailureOr slice = + createDim0ExtractSliceInClass(state, materializedClass, materializedClass.op->getLoc(), packed, firstOffset, rowCount); + if (failed(slice)) + return std::nullopt; + return *slice; +} + /* Returns the value recorded in exactValues for exactly (key, classId), or std::nullopt when either map level has no entry. */ +std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { + auto producerIt = exactValues.find(key); + if (producerIt == exactValues.end()) + return std::nullopt; + + auto valueIt = producerIt->second.find(classId); + if (valueIt == producerIt->second.end()) + return std::nullopt; + + return valueIt->second; +} + /* Extracts the index-th fragment-sized slice along dim 0 of packed: row offset is index * fragmentType.getDimSize(0) and the slice size is fragmentType.getDimSize(0). */ +FailureOr getPackedSliceForRunIndex(MaterializerState& state, + MaterializedClass& targetClass, + Value packed, + RankedTensorType fragmentType, + size_t index, + Location loc) { + int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, targetClass.op, rowOffset); + return createDim0ExtractSliceInClass(state, targetClass, loc, packed, firstOffset, fragmentType.getDimSize(0)); +} + /* Forward declaration; used by materializePackedScalarRunValue. */ +FailureOr createReceiveConcatLoop(MaterializerState& state, + MaterializedClass& targetClass, + RankedTensorType concatType, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc); + /* Callback types that receive a flat index Value. */ +using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; +using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; + /* Forward declaration; used by materializePackedScalarRunValue for deferred-local runs. */ +FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc); + /* Forward declaration. */ +SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run); + /* True when the run's kind is PackedScalarRunKind::DeferredLocalCompute. */ +bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { + return run.kind == PackedScalarRunKind::DeferredLocalCompute; +} + /* Total number of keys summed over all slots of the run. */ +size_t getPackedScalarRunReceiveCount(const 
PackedScalarRunValue& run) { + size_t count = 0; + for (const PackedScalarRunSlot& slot : run.slots) + count += slot.keys.size(); + return count; +} + /* Deferred-local runs are accepted without further checks. Any other run must have at least one key, run.messages must pass verify(anchor), and messages.size() must equal the total key count. */ +LogicalResult validatePackedScalarRunMetadata(Operation* anchor, const PackedScalarRunValue& run) { + if (run.kind == PackedScalarRunKind::DeferredLocalCompute) + return success(); + + size_t receiveCount = getPackedScalarRunReceiveCount(run); + + if (receiveCount == 0) + return anchor->emitError("packed scalar run has no receives"); + + if (failed(run.messages.verify(anchor))) + return failure(); + + if (run.messages.size() != receiveCount) + return anchor->emitError("packed scalar run receive metadata count is inconsistent"); + + return success(); +} + /* Returns run.packed when it is already set. Otherwise builds the packed value (deferred-local path, or a receive-concat loop over run.messages), caches it in run.packed and returns it. A run of kind Materialized without a packed value is reported as an error. */ +FailureOr materializePackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc) { + if (run.packed) + return run.packed; + + if (run.kind == PackedScalarRunKind::Materialized) + return targetClass.op->emitError("materialized packed scalar run has no packed value"); + + if (isDeferredLocalPackedScalarRun(run)) + return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc); + + if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) + return failure(); + + FailureOr fullPackedType = + getPackedBatchTensorType(run.fragmentType, getPackedScalarRunReceiveCount(run)); + if (failed(fullPackedType)) + return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); + + auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc); + if (failed(packed)) + return failure(); + run.packed = *packed; + return run.packed; +} + /* Searches packedScalarRuns for a run that targets classId and whose source op / result index match key, then returns the slice of that run's packed value holding key, recording it through record(). Returns std::nullopt when no run holds the key or any materialization/slicing step fails. */ +std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { + for (PackedScalarRunValue& run : packedScalarRuns) { + if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + 
+ size_t flattenedIndexBase = 0; /* number of keys in the slots skipped so far; base for indexing the flattened packed tensor */ + for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { + std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); + if (contiguousKey && containsProducerKey(*contiguousKey, key)) { /* slot-level path: slice the whole slot, then narrow to key if needed */ + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return std::nullopt; + + MaterializedClass& materializedClass = state.classes[classId]; + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + FailureOr packed = + materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); + if (failed(packed)) + return std::nullopt; + FailureOr slotPacked = + getPackedSliceForRunIndex(state, materializedClass, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); /* NOTE(review): the row offset here is slotIndex * dim-0 extent of slotPackedType, which presumably relies on every earlier slot holding the same number of keys as this one - confirm */ + if (failed(slotPacked)) + return std::nullopt; + + if (*contiguousKey == key) { + record(key, classId, *slotPacked); + return *slotPacked; + } + + std::optional sliced = + extractPackedProducerSlice(state, materializedClass, *contiguousKey, *slotPacked, key); + if (!sliced) + return std::nullopt; + + record(key, classId, *sliced); + return *sliced; + } + + auto keyIt = llvm::find(slot.keys, key); /* per-key path: key must appear verbatim in the slot */ + if (keyIt == slot.keys.end()) { + flattenedIndexBase += slot.keys.size(); + continue; + } + + MaterializedClass& materializedClass = state.classes[classId]; + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + FailureOr packed = + materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); + if (failed(packed)) + return std::nullopt; + size_t flattenedIndex = flattenedIndexBase + static_cast(std::distance(slot.keys.begin(), keyIt)); + FailureOr sliced = + getPackedSliceForRunIndex(state, materializedClass, *packed, run.fragmentType, flattenedIndex, (*packed).getLoc()); + if (failed(sliced)) + return std::nullopt; + record(key, classId, *sliced); + return *sliced; + } + } + + return 
std::nullopt; +} + /* Returns a pointer to the first indexed batch run that targets classId, matches key's source op and result index, and lists key in one of its slots; nullptr when there is none. */ +IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { + for (IndexedBatchRunValue& run : indexedBatchRuns) { + if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + for (const PackedScalarRunSlot& slot : run.slots) { + if (!llvm::is_contained(slot.keys, key)) + continue; + return &run; + } + } + return nullptr; +} + /* Resolves key for classId in this order: exact record, packed scalar runs, then slicing out of a recorded value whose key contains `key`. A slice obtained in the last step is recorded for (key, classId). If slicing the first matching candidate fails, std::nullopt is returned without trying further candidates. */ +std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { + + if (std::optional exact = lookupExact(key, classId)) { + return exact; + } + + if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) + return packedRunValue; + + MaterializedClass& materializedClass = state.classes[classId]; + + for (const auto& [candidateKey, classValues] : exactValues) { + if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key)) + continue; + + auto valueIt = classValues.find(classId); + if (valueIt == classValues.end()) + continue; + std::optional slice = + extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); + if (!slice) + return std::nullopt; + + record(key, classId, *slice); + return *slice; + } + return std::nullopt; +} + /* Builds a 1-D tensor constant with index element type holding `values` (each encoded as a 64-bit APInt), obtained through getOrCreateConstant. */ +Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { + SmallVector elements; + elements.reserve(values.size()); + for (int64_t value : values) + elements.push_back(APInt(64, value)); + + auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); + auto attr = DenseIntElementsAttr::get(type, elements); + return getOrCreateConstant(state.constantFolder, anchor, attr, type); +} + /* True when every element equals the first one; asserts that values is non-empty. */ +bool allEqual(ArrayRef values) { + assert(!values.empty() && "expected at least one value"); + for (int64_t value : values.drop_front()) + if (value != values.front()) + return false; + return true; +} + /* Closed-form description of an index sequence: plain affine (base, step) or, when isTiled is set, tiled (base, period, innerStep, outerStep). */ +struct IndexedIndexPattern { + 
int64_t base = 0; + int64_t step = 0; + int64_t period = 1; + int64_t innerStep = 0; + int64_t outerStep = 0; + bool isTiled = false; +}; + /* Matches values[i] == base + step * i with base = values[0] and step = values[1] - values[0] (0 for a single element). base, step and isTiled are written into pattern before checking, so they are modified even when the match fails. */ +bool matchAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { + assert(!values.empty() && "expected at least one value"); + + pattern.base = values.front(); + pattern.step = values.size() == 1 ? 0 : values[1] - values[0]; + pattern.isTiled = false; + + for (auto [index, value] : llvm::enumerate(values)) { + int64_t expected = pattern.base + pattern.step * static_cast(index); + if (value != expected) + return false; + } + + return true; +} + /* Matches values[i] == base + outerStep * (i / period) + innerStep * (i % period) for the given period, which must satisfy 2 <= period <= values.size() / 2. pattern is written only on success. */ +bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern, int64_t period) { + assert(!values.empty() && "expected at least one value"); + if (period < 2 || period > static_cast(values.size() / 2)) + return false; + + int64_t base = values.front(); + int64_t innerStep = values[1] - values[0]; + int64_t outerStep = values[period] - values[0]; + + for (auto [index, value] : llvm::enumerate(values)) { + int64_t i = static_cast(index); + int64_t expected = base + outerStep * (i / period) + innerStep * (i % period); + if (value != expected) + return false; + } + + pattern.base = base; + pattern.period = period; + pattern.innerStep = innerStep; + pattern.outerStep = outerStep; + pattern.isTiled = true; + return true; +} + /* Exhaustive search: tries every period from 2 up to values.size() / 2 in ascending order and keeps the first that matches. */ +bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { + assert(!values.empty() && "expected at least one value"); + + for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period) + if (matchTiledAffineSequence(values, pattern, period)) + return true; + + return false; +} + /* Tries, in order: plain affine; tiled with preferredPeriod when one is given; exhaustive tiled search, only when allowed and values.size() <= 256. Returns std::nullopt when nothing matches. */ +std::optional getIndexedIndexPattern(ArrayRef values, + std::optional preferredPeriod = std::nullopt, + bool allowExhaustiveTiledSearch = true) { + assert(!values.empty() && "expected at least one value"); + + IndexedIndexPattern pattern; + if (matchAffineSequence(values, pattern)) + return pattern; + if (preferredPeriod && 
matchTiledAffineSequence(values, pattern, *preferredPeriod)) + return pattern; + if (allowExhaustiveTiledSearch && values.size() <= 256 && matchTiledAffineSequence(values, pattern)) + return pattern; + + return std::nullopt; +} + /* Materializes pattern as a one-dimensional affine map applied to index: base + d0 * step, or for tiled patterns base + floordiv(d0, period) * outerStep + (d0 mod period) * innerStep. */ +Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) { + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + + AffineExpr expr; + if (!pattern.isTiled) { + expr = getAffineConstantExpr(pattern.base, context) + d0 * pattern.step; + } + else { + expr = getAffineConstantExpr(pattern.base, context) + d0.floorDiv(pattern.period) * pattern.outerStep + + (d0 % pattern.period) * pattern.innerStep; + } + + AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr); + return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func); +} + /* Produces a Value for values[index]: a constant when all values are equal, an affine expression of index when getIndexedIndexPattern finds a pattern, otherwise a tensor::ExtractOp reading from a constant lookup table. Asserts values is non-empty. */ +Value createIndexedIndexValue(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc, + std::optional preferredPeriod, + bool allowExhaustiveTiledSearch) { + assert(!values.empty() && "expected at least one indexed value"); + + if (allEqual(values)) { + return getOrCreateIndexConstant(state.constantFolder, anchor, values.front()); + } + + if (std::optional pattern = + getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch)) + return createAffineIndexValue(state, *pattern, index, loc); + Value table = createIndexTensorConstant(state, anchor, values); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult(); +} + /* Convenience overload: copies the int32 values into `widened` and forwards with no preferred period and the exhaustive tiled search enabled. */ +Value createIndexedIndexValue( + MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) { + assert(!values.empty() && "expected at least one indexed value"); + + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); + + return createIndexedIndexValue(state, anchor, 
ArrayRef(widened), index, loc, std::nullopt, true); +} + /* Returns a static index attribute when all values are equal, otherwise the Value produced by createIndexedIndexValue. Asserts values is non-empty. */ +OpFoldResult createIndexedOrStaticIndex(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc) { + assert(!values.empty() && "expected at least one indexed value"); + if (allEqual(values)) + return state.rewriter.getIndexAttr(values.front()); + return createIndexedIndexValue(state, anchor, values, index, loc); +} + /* Looks up messages.channelIds by index. */ +Value createIndexedChannelId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { + return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc); +} + /* Same lookup, forwarding preferredPeriod with the exhaustive tiled search enabled. */ +Value createIndexedChannelId(MaterializerState& state, + Operation* anchor, + const MessageVector& messages, + Value index, + Location loc, + std::optional preferredPeriod) { + return createIndexedIndexValue( + state, anchor, ArrayRef(messages.channelIds), index, loc, preferredPeriod, true); +} + /* Looks up messages.sourceCoreIds by index. */ +Value createIndexedSourceCoreId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { + return createIndexedIndexValue(state, anchor, ArrayRef(messages.sourceCoreIds), index, loc); +} + /* Same lookup with preferredPeriod; the ids are copied into `widened` before forwarding. */ +Value createIndexedSourceCoreId(MaterializerState& state, + Operation* anchor, + const MessageVector& messages, + Value index, + Location loc, + std::optional preferredPeriod) { + SmallVector widened(messages.sourceCoreIds.begin(), messages.sourceCoreIds.end()); + return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); +} + /* Looks up messages.targetCoreIds by index. */ +Value createIndexedTargetCoreId( + MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) { + return createIndexedIndexValue(state, anchor, ArrayRef(messages.targetCoreIds), index, loc); +} + /* Same lookup with preferredPeriod; the ids are copied into `widened` before forwarding. */ +Value createIndexedTargetCoreId(MaterializerState& state, + Operation* anchor, + const MessageVector& messages, + Value index, + Location loc, + std::optional preferredPeriod) { + SmallVector 
widened(messages.targetCoreIds.begin(), messages.targetCoreIds.end()); + return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true); +} + /* Looks up values by the lane argument of the batch op behind materializedClass. Asserts that the class is a batch class, that there is one value per entry of materializedClass.cpus, and that the lane argument exists. */ +Value createLaneIndexedIndexValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef values, + Location loc) { + assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); + assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); + + auto batch = cast(materializedClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected compute_batch lane argument"); + + return createIndexedIndexValue(state, materializedClass.op, values, *laneArg, loc); +} + /* Convenience overload: copies the int32 values into `widened` and forwards to the overload above. */ +Value createLaneIndexedIndexValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef values, + Location loc) { + assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); + assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); + + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); + + return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc); +} + /* Rewrites an index-like OpFoldResult so that sourceLaneArg is replaced by mappedLaneValue. Attributes and constants pass through unchanged; a single-result affine-apply producer is rebuilt over recursively remapped operands; any other value yields failure(). */ +FailureOr remapProjectionIndexLike(MaterializerState& state, + Operation* anchor, + OpFoldResult value, + Value sourceLaneArg, + Value mappedLaneValue, + Location loc) { + if (auto attr = dyn_cast(value)) + return value; + + Value operand = cast(value); + if (operand == sourceLaneArg) + return OpFoldResult(mappedLaneValue); + + if (matchPattern(operand, m_Constant())) + return getAsOpFoldResult(operand); + + auto affineApply = operand.getDefiningOp(); + if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return failure(); + + SmallVector remappedOperands; + remappedOperands.reserve(affineApply.getMapOperands().size()); + for (Value mapOperand : 
affineApply.getMapOperands()) { + FailureOr remapped = + remapProjectionIndexLike(state, anchor, OpFoldResult(mapOperand), sourceLaneArg, mappedLaneValue, loc); + if (failed(remapped)) + return failure(); + remappedOperands.push_back(getValueOrCreateConstantIndexOp(state.rewriter, loc, *remapped)); + } + + return getAsOpFoldResult( + createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func)); +} + /* Produces the laneStart of the key that corresponds to the current batch lane of sourceClass. A single key yields a constant; otherwise there must be exactly one key per entry of sourceClass.cpus and the value is looked up by the lane argument. Every key must have laneCount == 1; violations are reported as errors on the batch op. */ +FailureOr createProjectionLaneValueForKeys(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Location loc) { + if (!sourceClass.isBatch) + return sourceClass.op->emitError("projection lane mapping expects a batch materialized class"); + + auto batch = cast(sourceClass.op); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing lane argument for projected batch host publication"); + + if (keys.size() == 1) { + if (keys.front().instance.laneCount != 1) + return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); + return getOrCreateIndexConstant(state.constantFolder, sourceClass.op, keys.front().instance.laneStart); + } + + if (keys.size() != sourceClass.cpus.size()) + return batch.emitOpError("projected batch host publication expected one producer key per materialized batch lane"); + + SmallVector sourceLanes; + sourceLanes.reserve(keys.size()); + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return batch.emitOpError("projected batch host publication expects one logical lane per fragment"); + sourceLanes.push_back(key.instance.laneStart); + } + + return createIndexedIndexValue(state, sourceClass.op, sourceLanes, *laneArg, loc, std::nullopt, true); +} + /* Collects, for each cpu of materializedClass, the logical instance stored at logicalSlot in state.logicalInstancesByCpu. Fails when a cpu has no entry or logicalSlot is out of range for it. */ +FailureOr> +getPeerLogicalInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId logicalSlot) { + SmallVector peers; + peers.reserve(materializedClass.cpus.size()); + for (CpuId cpu : materializedClass.cpus) { + auto 
streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) + return failure(); + peers.push_back(streamIt->second[logicalSlot]); + } + return peers; +} + /* For a non-batch class returns the first peer's laneStart as a constant. For a batch class returns the laneStart of the peer selected by the batch lane argument. Asserts peers is non-empty. */ +Value createOriginalLaneValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef peers, + Location loc) { + assert(!peers.empty() && "expected at least one peer instance"); + if (!materializedClass.isBatch) + return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart); + + auto batch = cast(materializedClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected materialized compute_batch lane argument"); + + SmallVector laneValues; + laneValues.reserve(peers.size()); + for (const ComputeInstance& peer : peers) + laneValues.push_back(peer.laneStart); + + return createIndexedIndexValue(state, materializedClass.op, ArrayRef(laneValues), *laneArg, loc); +} + /* Worklist walk over the transitive uses of value. Users inside the old compute ops are ignored; users that pass the isa check are looked through by enqueuing their results; any other user makes the function return true. */ +bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) { + SmallVector worklist {value}; + DenseSet visited; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current).second) + continue; + + for (OpOperand& use : current.getUses()) { + Operation* owner = use.getOwner(); + if (isInsideOldCompute(owner, oldComputeOps)) + continue; + if (isa(owner)) { + for (Value result : owner->getResults()) + worklist.push_back(result); + continue; + } + return true; + } + } + + return false; +} + /* Same walk as hasLiveExternalUse, except that users passing a second isa check are skipped instead of counting as consumers. */ +bool hasRealComputeConsumer(Value value, const DenseSet& oldComputeOps) { + SmallVector worklist {value}; + DenseSet visited; + + while (!worklist.empty()) { + Value current = worklist.pop_back_val(); + if (!visited.insert(current).second) + continue; + + for (OpOperand& use : current.getUses()) { + Operation* owner = use.getOwner(); + if (isInsideOldCompute(owner, oldComputeOps)) + continue; + if (isa(owner)) { + for (Value result : 
owner->getResults()) + worklist.push_back(result); + continue; + } + if (isa(owner)) + continue; + return true; + } + } + + return false; +} + /* Forward declaration; used by isProjectedTerminalBatchHostOutput. */ +FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); + /* True when output is defined by a batch op with at least one result, has a live external use, and has no real compute consumer. */ +bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps) { + auto batch = dyn_cast_or_null(output.getDefiningOp()); + if (!batch || batch.getNumResults() == 0) + return false; + if (!hasLiveExternalUse(output, oldComputeOps)) + return false; + return !hasRealComputeConsumer(output, oldComputeOps); +} + /* A terminal host batch output whose projection insert (from getBatchResultProjectionInsert) has a source type different from the output's own type. */ +bool isProjectedTerminalBatchHostOutput(Value output, const DenseSet& oldComputeOps) { + if (!isTerminalHostBatchOutput(output, oldComputeOps)) + return false; + + auto batch = dyn_cast_or_null(output.getDefiningOp()); + auto originalResult = dyn_cast(output); + if (!batch || !originalResult) + return false; + + FailureOr projection = + getBatchResultProjectionInsert(batch, originalResult.getResultNumber()); + if (failed(projection)) + return false; + + return projection->getSource().getType() != output.getType(); +} + /* Emits one error on sourceClass.op describing the source class, the output's use status, its destination classes and the producer keys; always returns failure(). NOTE(review): keys.front() is read without an emptiness check - presumably callers always pass at least one key; confirm. */ +LogicalResult emitBatchToScalarDestinationDiagnostic(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value originalOutput) { + auto diag = sourceClass.op->emitError("resultful compute_batch output would enter batch-to-scalar class fanout"); + diag << " sourceClassId=" << sourceClass.id << " sourceKind=" << (sourceClass.isBatch ? "batch" : "scalar"); + diag << " liveExternalUse=" << (hasLiveExternalUseCached(state, originalOutput) ? "true" : "false"); + diag << " terminalHostBatch=" << (isTerminalHostBatchOutput(originalOutput, state.oldComputeOps) ? "true" : "false"); + diag << " originalDef=" + << (originalOutput.getDefiningOp() ? 
originalOutput.getDefiningOp()->getName().getStringRef() : StringRef("")); + + bool first = true; + diag << " destinationClasses=["; + auto destIt = state.producerDestClasses.find(keys.front()); + ArrayRef destinations = destIt == state.producerDestClasses.end() ? ArrayRef {} : ArrayRef(destIt->second); + for (ClassId classId : destinations) { + if (!first) + diag << ", "; + first = false; + const MaterializedClass& destClass = state.classes[classId]; + diag << classId << ":" << (destClass.isBatch ? "batch" : "scalar"); + } + diag << "]"; + + diag << " producerKeys=["; + first = true; + for (ProducerKey key : keys) { + if (!first) + diag << ", "; + first = false; + diag << key.instance.op->getName().getStringRef() << ":r" << key.resultIndex << ":laneStart=" << key.instance.laneStart + << ":laneCount=" << key.instance.laneCount; + } + diag << "]"; + return failure(); +} + /* Adds classId to the destination list of key in state.producerDestClasses unless it is already present. */ +void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { + SmallVector& destinations = state.producerDestClasses[key]; + if (!llvm::is_contained(destinations, classId)) + destinations.push_back(classId); +} + /* Redirects every use of oldValue that is not inside the old compute ops to replacement. The uses are copied into a vector first, presumably because set() modifies the use list being iterated. */ +void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet& oldComputeOps) { + SmallVector uses; + for (OpOperand& use : oldValue.getUses()) + uses.push_back(&use); + + for (OpOperand* use : uses) { + Operation* owner = use->getOwner(); + if (isInsideOldCompute(owner, oldComputeOps)) + continue; + use->set(replacement); + } +} + /* Visits every logical consumer and classifies each producer key of each of its inputs: a producer in the consumer's own class is added to state.sameClassConsumerIndex, a producer in another class gets the consumer's class appended as a destination. Fails when a producer's scheduled chunk has no cpu in the schedule. */ +LogicalResult collectProducerDestinations(MaterializerState& state) { + return forEachLogicalConsumerInMaterializationOrder( + state, + [&](CpuId, ClassId targetClass, ComputeInstance scheduledConsumer, ComputeInstance logicalConsumer, SlotId) + -> LogicalResult { + for (Value input : getComputeInstanceInputs(scheduledConsumer)) { + for (ProducerKey producerKey : collectProducerKeysForDestinations(input, logicalConsumer)) { + ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, 
producerKey.instance); + auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + return logicalConsumer.op->emitError( + "schedule materialization found an input produced by an unscheduled compute"); + + ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClass == targetClass) { + SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass}; + SmallVector& bucket = state.sameClassConsumerIndex[lookupKey]; + if (!llvm::is_contained(bucket, producerKey)) + bucket.push_back(producerKey); + continue; + } + + appendDestinationClass(state, producerKey, targetClass); + } + } + + return success(); + }); +} + /* True when offsets has as many entries as the rank of both types and, in every dimension, the offset is non-negative, both extents are non-negative, and offset + fragment extent does not exceed the source extent. */ +bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceType, RankedTensorType fragmentType) { + if (offsets.size() != static_cast(sourceType.getRank()) + || offsets.size() != static_cast(fragmentType.getRank())) + return false; + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + int64_t offset = offsets[dim]; + if (offset < 0) + return false; + + int64_t sourceDimSize = sourceType.getDimSize(dim); + int64_t fragmentDimSize = fragmentType.getDimSize(dim); + if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) + return false; + if (offset > sourceDimSize - fragmentDimSize) /* written as a subtraction so the bound check itself cannot overflow */ + return false; + } + + return true; +} + + /* True when, in every dimension, [innerOffset, innerOffset + innerSize) lies within [outerOffset, outerOffset + outerSize). All four arrays must have the same length and no size may be negative. */ +bool isStaticSliceContainedIn(ArrayRef innerOffsets, + ArrayRef innerSizes, + ArrayRef outerOffsets, + ArrayRef outerSizes) { + if (innerOffsets.size() != innerSizes.size() || outerOffsets.size() != outerSizes.size() + || innerOffsets.size() != outerOffsets.size()) + return false; + + for (size_t dim = 0; dim < innerOffsets.size(); ++dim) { + if (innerSizes[dim] < 0 || outerSizes[dim] < 0) + return false; + + int64_t innerBegin = innerOffsets[dim]; + int64_t innerEnd = innerBegin + innerSizes[dim]; + int64_t outerBegin = outerOffsets[dim]; + int64_t 
outerEnd = outerBegin + outerSizes[dim]; + if (innerBegin < outerBegin || innerEnd > outerEnd) + return false; + } + + return true; +} + /* True when every stride equals 1 (vacuously true for an empty list). */ +bool areAllUnitStrides(ArrayRef strides) { + return llvm::all_of(strides, [](int64_t stride) { return stride == 1; }); +} + /* Trip count ceil((upper - lower) / step) of a loop whose bounds and step are constants with step > 0 and upper >= lower; std::nullopt otherwise. */ +static std::optional getStaticForTripCount(scf::ForOp loop) { + std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); + std::optional upperBound = matchConstantIndexValue(loop.getUpperBound()); + std::optional step = matchConstantIndexValue(loop.getStep()); + if (!lowerBound || !upperBound || !step || *step <= 0 || *upperBound < *lowerBound) + return std::nullopt; + + int64_t distance = *upperBound - *lowerBound; + return (distance + *step - 1) / *step; +} + /* Collects the loops enclosing op, ordered outermost first, together with their constant lower bound, step and trip count. Returns an empty vector as soon as one enclosing loop is not static. */ +static SmallVector collectEnclosingStaticProjectedLoops(Operation* op) { + SmallVector loops; + SmallVector reversedLoops; + for (Operation* current = op->getParentOp(); current; current = current->getParentOp()) + if (auto loop = dyn_cast(current)) + reversedLoops.push_back(loop); + + for (scf::ForOp loop : llvm::reverse(reversedLoops)) { + std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound()); + std::optional step = matchConstantIndexValue(loop.getStep()); + std::optional tripCount = getStaticForTripCount(loop); + if (!lowerBound || !step || !tripCount) + return {}; + loops.push_back(StaticProjectedLoopInfo {.iv = cast(loop.getInductionVar()), + .lowerBound = *lowerBound, + .step = *step, + .tripCount = *tripCount}); + } + return loops; +} + /* True when value is built only from laneArg, the induction variables in loops, constants, and single-result affine-apply producers over those. usesDynamicBinding is set to true (never reset to false) when laneArg or an induction variable is involved. */ +static bool +isProjectedOffsetValue(Value value, Value laneArg, ArrayRef loops, bool& usesDynamicBinding) { + if (value == laneArg) { + usesDynamicBinding = true; + return true; + } + + for (const StaticProjectedLoopInfo& loop : loops) { + if (value == loop.iv) { + usesDynamicBinding = true; + return true; + } + } + + if (matchPattern(value, m_Constant())) + return true; + + auto affineApply = value.getDefiningOp(); + if (!affineApply || 
affineApply.getAffineMap().getNumResults() != 1) + return false; + + bool nestedUsesDynamicBinding = false; + for (Value operand : affineApply.getMapOperands()) { + bool operandUsesDynamicBinding = false; + if (!isProjectedOffsetValue(operand, laneArg, loops, operandUsesDynamicBinding)) + return false; + nestedUsesDynamicBinding = nestedUsesDynamicBinding || operandUsesDynamicBinding; + } + + usesDynamicBinding = usesDynamicBinding || nestedUsesDynamicBinding; + return true; +} + +static std::optional getConstantIndex(OpFoldResult value); + +static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef loopTripCounts) { + unsigned fragmentsPerLogicalSlot = 1; + for (int64_t tripCount : loopTripCounts) { + assert(tripCount > 0 && "projected loop trip counts must be positive"); + fragmentsPerLogicalSlot *= static_cast(tripCount); + } + return fragmentsPerLogicalSlot; +} + +LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) { + if (!layout.fragmentType || layout.fragmentShape.empty()) + return anchor->emitError("projected fragment layout is missing fragment type metadata"); + if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank())) + return anchor->emitError("projected fragment layout rank does not match fragment type"); + if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) + return anchor->emitError("projected fragment layout has an invalid fragment count"); + if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0) + return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots"); + return success(); +} + +FailureOr +getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) { + if (failed( + verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type"))) + return failure(); + return 
getPackedBatchTensorType(fragmentType, payloadFragmentCount); +} + +SmallVector, 4> +buildProjectedFragmentOffsetsByDim(ArrayRef> fragmentOffsets, size_t rank) { + SmallVector, 4> fragmentOffsetsByDim(rank); + for (ArrayRef offsets : fragmentOffsets) { + assert(offsets.size() == rank && "projected offset rank mismatch"); + for (size_t dim = 0; dim < rank; ++dim) + fragmentOffsetsByDim[dim].push_back(offsets[dim]); + } + return fragmentOffsetsByDim; +} + +LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) { + if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout))) + return failure(); + if (!descriptor.payloadType) + return anchor->emitError("projected transfer descriptor is missing payload type"); + if (descriptor.fragmentOffsets.empty()) + return anchor->emitError("projected transfer descriptor expected at least one fragment offset"); + if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size()) + return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); + for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) + if (dimOffsets.size() != descriptor.fragmentOffsets.size()) + return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent"); + for (ArrayRef offsets : descriptor.fragmentOffsets) + if (offsets.size() != descriptor.layout.fragmentShape.size()) + return anchor->emitError("projected transfer offset rank does not match fragment rank"); + return success(); +} + +LogicalResult verifyProjectedSendDescriptor(Operation* anchor, + const ProjectedTransferDescriptor& descriptor, + const MessageVector& messages) { + if (failed(verifyProjectedTransferDescriptor(anchor, descriptor))) + return failure(); + if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size()) + return anchor->emitError("projected send descriptor metadata is inconsistent"); + 
return success(); +} + +LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) { + descriptor.fragmentOffsetsByDim = + buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size()); + + FailureOr payloadType = + getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount); + if (failed(payloadType)) + return failure(); + if (descriptor.payloadType && descriptor.payloadType != *payloadType) + return anchor->emitError("projected transfer descriptor payload type does not match projected layout"); + descriptor.payloadType = *payloadType; + + return verifyProjectedTransferDescriptor(anchor, descriptor); +} + +static FailureOr evaluateProjectedOffsetValue(OpFoldResult value, + Value laneArg, + uint32_t lane, + ArrayRef loops, + ArrayRef loopIterationIndices) { + if (std::optional constant = getConstantIndex(value)) + return *constant; + + Value current = dyn_cast(value); + if (!current) + return failure(); + if (current == laneArg) + return static_cast(lane); + + for (auto [index, loop] : llvm::enumerate(loops)) { + if (current != loop.iv) + continue; + if (index >= loopIterationIndices.size()) + return failure(); + return loop.lowerBound + loopIterationIndices[index] * loop.step; + } + + if (auto affineApply = current.getDefiningOp()) { + return evaluateAffineApply(affineApply, [&](Value operand) { + return evaluateProjectedOffsetValue(operand, laneArg, lane, loops, loopIterationIndices); + }); + } + + return failure(); +} + +static std::optional getConstantIndex(OpFoldResult value) { + if (auto attr = dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return std::nullopt; + return intAttr.getInt(); + } + + Value operand = dyn_cast(value); + if (!operand) + return std::nullopt; + + if (auto constantIndex = operand.getDefiningOp()) + return constantIndex.value(); + + APInt apInt; + if (matchPattern(operand, 
m_ConstantInt(&apInt))) { + if (apInt.isNegative()) + return std::nullopt; + return static_cast(apInt.getSExtValue()); + } + + return std::nullopt; +} + +static std::optional matchAffineProjectedInputSlice(SpatComputeBatch batch, + unsigned inputIndex) { + const auto fail = [&](StringRef) -> std::optional { return std::nullopt; }; + + std::optional inputArg = batch.getInputArgument(inputIndex); + std::optional laneArg = batch.getLaneArgument(); + if (!inputArg || !laneArg) + return fail("missing-input-or-lane-arg"); + + if (!inputArg->hasOneUse()) + return fail("input-arg-not-one-use"); + + Operation* user = *inputArg->getUsers().begin(); + auto extract = dyn_cast(user); + if (!extract || extract.getSource() != *inputArg) + return fail("input-user-is-not-direct-extract-slice"); + + auto inputType = dyn_cast(inputArg->getType()); + auto fragmentType = dyn_cast(extract.getResult().getType()); + if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape()) + return fail("non-static-ranked-input-or-fragment"); + + if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank()) + return fail("rank-mismatch-or-rank-zero"); + + SmallVector offsets = extract.getMixedOffsets(); + SmallVector sizes = extract.getMixedSizes(); + SmallVector strides = extract.getMixedStrides(); + + if (offsets.size() != static_cast(inputType.getRank()) + || sizes.size() != static_cast(inputType.getRank()) + || strides.size() != static_cast(inputType.getRank())) + return fail("slice-rank-mismatch"); + + SmallVector loops = collectEnclosingStaticProjectedLoops(extract.getOperation()); + if (extract->getParentOfType() && loops.empty()) + return fail("unsupported-enclosing-loop"); + + bool hasDynamicProjection = false; + for (auto [dim, offset] : llvm::enumerate(offsets)) { + bool usesDynamicBinding = false; + if (auto value = dyn_cast(offset)) { + if (!isProjectedOffsetValue(value, *laneArg, loops, usesDynamicBinding)) + return std::nullopt; + } + 
else if (!isa(offset)) + return std::nullopt; + if (std::optional stride = getConstantIndex(strides[dim]); !stride || *stride != 1) + return std::nullopt; + std::optional size = getConstantIndex(sizes[dim]); + if (!size || *size != fragmentType.getDimSize(dim)) + return std::nullopt; + hasDynamicProjection = hasDynamicProjection || usesDynamicBinding; + } + + if (!hasDynamicProjection) + return fail("no-dynamic-projection"); + + for (int64_t dim = 0; dim < inputType.getRank(); ++dim) + if (fragmentType.getDimSize(dim) <= 0 || fragmentType.getDimSize(dim) > inputType.getDimSize(dim)) + return std::nullopt; + + AffineProjectedInputSliceMatch match; + match.extract = extract; + match.sourceType = inputType; + match.fragmentType = fragmentType; + match.offsets.assign(offsets.begin(), offsets.end()); + match.fragmentShape.assign(fragmentType.getShape().begin(), fragmentType.getShape().end()); + match.loops = std::move(loops); + return match; +} + +std::optional +getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex) { + ProjectedBatchInputKey key {batch.getOperation(), inputIndex}; + auto cached = state.projectedInputMatches.find(key); + if (cached != state.projectedInputMatches.end()) + return cached->second; + if (state.nonProjectedInputs.contains(key)) + return std::nullopt; + + std::optional match = matchAffineProjectedInputSlice(batch, inputIndex); + if (!match) { + state.nonProjectedInputs.insert(key); + return std::nullopt; + } + + state.projectedInputMatches.insert({key, *match}); + return match; +} + +LogicalResult collectProjectedTransfers(MaterializerState& state) { + struct PendingProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + Operation* extractOp = nullptr; + RankedTensorType sourceType; + RankedTensorType fragmentType; + SmallVector fragmentShape; + SmallVector, 16>, 8> fragmentOffsetsByLane; + SmallVector loopLowerBounds; + SmallVector loopSteps; + SmallVector loopTripCounts; + bool invalid = 
false; + }; + + DenseMap, ProducerKeyInfo> pending; + + const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) { + if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType) + return false; + + if (descriptor.fragmentOffsetsByLane.size() != 1) + return false; + + ArrayRef> fragments = descriptor.fragmentOffsetsByLane.front(); + if (fragments.size() != 1) + return false; + + return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; }); + }; + + const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor, + unsigned targetLane, + const AffineProjectedInputSliceMatch& match, + Value laneArg, + uint32_t lane) -> LogicalResult { + SmallVector loopIterationIndices; + loopIterationIndices.resize(match.loops.size(), 0); + + const auto appendOneFragment = [&]() -> LogicalResult { + SmallVector evaluatedOffsets; + evaluatedOffsets.reserve(match.offsets.size()); + for (OpFoldResult offset : match.offsets) { + FailureOr evaluated = + evaluateProjectedOffsetValue(offset, laneArg, lane, match.loops, loopIterationIndices); + if (failed(evaluated)) + return failure(); + evaluatedOffsets.push_back(*evaluated); + } + + if (!isStaticSliceInBounds(evaluatedOffsets, match.sourceType, match.fragmentType)) + return failure(); + + descriptor.fragmentOffsetsByLane[targetLane].push_back(std::move(evaluatedOffsets)); + return success(); + }; + + if (match.loops.empty()) + return appendOneFragment(); + + const auto recurse = [&](auto&& self, size_t loopIndex) -> LogicalResult { + if (loopIndex == match.loops.size()) + return appendOneFragment(); + + for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { + loopIterationIndices[loopIndex] = iteration; + if (failed(self(self, loopIndex + 1))) + return failure(); + } + return success(); + }; + + return recurse(recurse, 0); + }; + + if (failed(forEachLogicalConsumerInMaterializationOrder( + state, + 
[&](CpuId cpu, + ClassId targetClassId, + ComputeInstance consumer, + ComputeInstance logicalConsumer, + SlotId logicalSlot) -> LogicalResult { + auto batch = dyn_cast(consumer.op); + if (!batch) + return success(); + + MaterializedClass& targetClass = state.classes[targetClassId]; + unsigned targetLane = 0; + if (targetClass.isBatch) { + auto targetLaneIt = targetClass.cpuToLane.find(cpu); + if (targetLaneIt == targetClass.cpuToLane.end()) + return consumer.op->emitError("projected transfer collection could not recover target lane"); + targetLane = targetLaneIt->second; + } + + for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { + SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); + if (producers.size() != 1) + continue; + ProducerKey producer = producers.front(); + + ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producer.instance); + auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + continue; + + ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClassId == targetClassId) + continue; + + std::optional match = + getProjectedInputSliceMatch(state, batch, static_cast(inputIndex)); + if (!match) + continue; + if (!isProjectedInputSliceCompatibleWithProducerFragments( + batch, *match, producer, logicalConsumer.laneStart)) + continue; + + PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; + if (descriptor.fragmentOffsetsByLane.empty()) { + descriptor.inputKey = {batch.getOperation(), static_cast(inputIndex)}; + descriptor.extractOp = match->extract.getOperation(); + descriptor.sourceType = match->sourceType; + descriptor.fragmentType = match->fragmentType; + descriptor.fragmentShape = match->fragmentShape; + descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? 
targetClass.cpus.size() : 1); + descriptor.loopLowerBounds.reserve(match->loops.size()); + descriptor.loopSteps.reserve(match->loops.size()); + descriptor.loopTripCounts.reserve(match->loops.size()); + for (const StaticProjectedLoopInfo& loop : match->loops) { + descriptor.loopLowerBounds.push_back(loop.lowerBound); + descriptor.loopSteps.push_back(loop.step); + descriptor.loopTripCounts.push_back(loop.tripCount); + } + } + + ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; + if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation() + || descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType + || descriptor.fragmentShape != match->fragmentShape + || descriptor.loopLowerBounds.size() != match->loops.size()) { + descriptor.invalid = true; + continue; + } + for (auto [index, loop] : llvm::enumerate(match->loops)) { + if (descriptor.loopLowerBounds[index] != loop.lowerBound || descriptor.loopSteps[index] != loop.step + || descriptor.loopTripCounts[index] != loop.tripCount) { + descriptor.invalid = true; + break; + } + } + if (descriptor.invalid) + continue; + + if (targetLane >= descriptor.fragmentOffsetsByLane.size()) { + descriptor.invalid = true; + continue; + } + + if (failed(appendEvaluatedFragments( + descriptor, targetLane, *match, *batch.getLaneArgument(), logicalConsumer.laneStart))) { + descriptor.invalid = true; + continue; + } + + (void) logicalSlot; + } + + return success(); + }))) + return failure(); + + for (auto& producerEntry : pending) { + ProducerKey producer = producerEntry.first; + for (auto& classEntry : producerEntry.second) { + ClassId targetClassId = classEntry.first; + PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second; + + if (pendingDescriptor.invalid) + continue; + if (pendingDescriptor.fragmentOffsetsByLane.empty()) + continue; + if (isIdentityProjectedTransfer(pendingDescriptor)) + continue; + + 
MaterializedClass& targetClass = state.classes[targetClassId]; + ProjectedTransferDescriptor descriptor; + descriptor.inputKey = pendingDescriptor.inputKey; + descriptor.extractOp = pendingDescriptor.extractOp; + descriptor.layout.fragmentType = pendingDescriptor.fragmentType; + descriptor.layout.fragmentShape = pendingDescriptor.fragmentShape; + descriptor.layout.loopLowerBounds = pendingDescriptor.loopLowerBounds; + descriptor.layout.loopSteps = pendingDescriptor.loopSteps; + descriptor.layout.loopTripCounts = pendingDescriptor.loopTripCounts; + descriptor.layout.fragmentsPerLogicalSlot = getProjectedFragmentsPerLogicalSlot(descriptor.layout.loopTripCounts); + if (targetClass.isBatch) { + unsigned payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); + if (payloadFragmentCount == 0) + continue; + + // Batch-target projected replacements currently select fragments with the + // local materialization-run slot index. That is only unambiguous when each + // target lane receives one projected fragment. Multi-fragment payloads + // need an explicit producer-key to payload-slot mapping; otherwise two + // independently materialized runs can both select fragment 0 from the same + // packed receive and duplicate rows. 
+ if (payloadFragmentCount != 1) + continue; + + bool uniform = true; + for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) { + if (laneFragments.size() != payloadFragmentCount) { + uniform = false; + break; + } + } + if (!uniform) + continue; + + descriptor.layout.payloadFragmentCount = payloadFragmentCount; + descriptor.fragmentOffsets.reserve(pendingDescriptor.fragmentOffsetsByLane.size() * payloadFragmentCount); + for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) + llvm::append_range(descriptor.fragmentOffsets, laneFragments); + } + else { + if (pendingDescriptor.fragmentOffsetsByLane.size() != 1) + return targetClass.op->emitError("scalar projected transfer descriptor expected one local offset stream"); + if (pendingDescriptor.fragmentOffsetsByLane.front().empty()) + continue; + + descriptor.layout.payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); + llvm::append_range(descriptor.fragmentOffsets, pendingDescriptor.fragmentOffsetsByLane.front()); + if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) + return targetClass.op->emitError("scalar projected transfer offset count does not match the local run"); + } + if (failed(finalizeProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + + state.projectedTransfers[producer][targetClassId] = std::move(descriptor); + } + } + + return success(); +} + +static std::optional +collectScalarTargetProjectedDescriptor(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef keys, + bool requirePackedRunOffsetCountMatch) { + assert(!targetClass.isBatch && "scalar target projected descriptor helper expects a scalar class"); + + std::optional combined; + for (ProducerKey key : keys) { + auto producerIt = state.projectedTransfers.find(key); + if (producerIt == state.projectedTransfers.end()) + return std::nullopt; + + auto descriptorIt = producerIt->second.find(targetClass.id); + if 
(descriptorIt == producerIt->second.end()) + return std::nullopt; + + const ProjectedTransferDescriptor& descriptor = descriptorIt->second; + if (descriptor.fragmentOffsets.empty()) + return std::nullopt; + if (descriptor.layout.payloadFragmentCount == 0 || descriptor.layout.fragmentsPerLogicalSlot == 0) + return std::nullopt; + if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) + return std::nullopt; + if (descriptor.layout.payloadFragmentCount % descriptor.layout.fragmentsPerLogicalSlot != 0) + return std::nullopt; + + if (!combined) { + combined = descriptor; + continue; + } + + if (!(combined->inputKey == descriptor.inputKey) || combined->extractOp != descriptor.extractOp + || combined->layout.fragmentType != descriptor.layout.fragmentType + || combined->layout.fragmentShape != descriptor.layout.fragmentShape + || combined->layout.loopLowerBounds != descriptor.layout.loopLowerBounds + || combined->layout.loopSteps != descriptor.layout.loopSteps + || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts + || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot) + return std::nullopt; + + combined->layout.payloadFragmentCount += descriptor.layout.payloadFragmentCount; + llvm::append_range(combined->fragmentOffsets, descriptor.fragmentOffsets); + } + + if (!combined) + return std::nullopt; + + if (combined->fragmentOffsets.size() != combined->layout.payloadFragmentCount) + return std::nullopt; + + if (requirePackedRunOffsetCountMatch) { + if (combined->layout.payloadFragmentCount != keys.size() * combined->layout.fragmentsPerLogicalSlot) + return std::nullopt; + } + if (failed(finalizeProjectedTransferDescriptor(targetClass.op, *combined))) + return std::nullopt; + return combined; +} + +bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { + if (keys.empty()) + return true; + + auto firstIt = state.producerDestClasses.find(keys.front()); + ArrayRef first = firstIt == 
state.producerDestClasses.end() ? ArrayRef() : firstIt->second; + for (ProducerKey key : keys.drop_front()) { + auto it = state.producerDestClasses.find(key); + ArrayRef current = it == state.producerDestClasses.end() ? ArrayRef() : it->second; + if (first.size() != current.size()) + return false; + for (auto [lhs, rhs] : llvm::zip(first, current)) + if (lhs != rhs) + return false; + } + return true; +} + +ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey key) { + auto it = state.producerDestClasses.find(key); + if (it == state.producerDestClasses.end()) + return {}; + return it->second; +} + +std::optional getKnownMinimumIndexValue(Value value) { + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; + + if (auto blockArg = dyn_cast(value)) { + if (blockArg.getArgNumber() == 0) { + if (auto loop = dyn_cast_or_null(blockArg.getOwner()->getParentOp())) + return matchConstantIndexValue(loop.getLowerBound()); + } + return std::nullopt; + } + + if (auto add = value.getDefiningOp()) { + std::optional lhs = getKnownMinimumIndexValue(add.getLhs()); + std::optional rhs = getKnownMinimumIndexValue(add.getRhs()); + if (lhs && rhs) + return *lhs + *rhs; + return std::nullopt; + } + + if (auto mul = value.getDefiningOp()) { + std::optional lhs = getKnownMinimumIndexValue(mul.getLhs()); + std::optional rhs = getKnownMinimumIndexValue(mul.getRhs()); + if (!lhs || !rhs) + return std::nullopt; + if (*lhs >= 0 && *rhs >= 0) + return *lhs * *rhs; + return std::nullopt; + } + + auto affineApply = value.getDefiningOp(); + if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return std::nullopt; + + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + std::optional minimum = getKnownMinimumIndexValue(operand); + if (!minimum) + return std::nullopt; + operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *minimum)); + } + + 
SmallVector results; + if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) + return std::nullopt; + + auto intAttr = dyn_cast(results.front()); + if (!intAttr) + return std::nullopt; + return intAttr.getInt(); +} + +std::optional getKnownMinimumCommunicationChannelId(Operation* op) { + if (auto send = dyn_cast(op)) + return getKnownMinimumIndexValue(send.getChannelId()); + if (auto receive = dyn_cast(op)) + return getKnownMinimumIndexValue(receive.getChannelId()); + + std::optional minimum; + op->walk([&](Operation* nested) { + if (nested == op) + return; + std::optional nestedMinimum = getKnownMinimumCommunicationChannelId(nested); + if (!nestedMinimum) + return; + if (!minimum || *nestedMinimum < *minimum) + minimum = *nestedMinimum; + }); + return minimum; +} + +void setInsertionPointForScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + int64_t channelId) { + assert(!targetClass.isBatch && "scalar receive ordering expects a scalar target class"); + + for (Operation& op : *targetClass.body) { + if (op.hasTrait()) + break; + + std::optional existingChannel = getKnownMinimumCommunicationChannelId(&op); + if (existingChannel && *existingChannel > channelId) { + state.rewriter.setInsertionPoint(&op); + return; + } + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); +} + +// ----------------------------------------------------------------------------- +// Communication materialization helpers. 
+// ----------------------------------------------------------------------------- + +constexpr const char* kRaptorMinChannelIdAttr = "raptor.min_channel_id"; +constexpr const char* kRaptorMaterializerAttr = "raptor.materializer"; +constexpr const char* kRaptorCommTraceIdAttr = "raptor.comm_trace_id"; +constexpr const char* kRaptorCommTraceKindAttr = "raptor.comm_trace_kind"; +constexpr const char* kRaptorCommTracePhaseAttr = "raptor.comm_trace_phase"; +constexpr const char* kRaptorCommTraceClassIdAttr = "raptor.comm_trace_class_id"; +constexpr const char* kRaptorCommTraceClassKindAttr = "raptor.comm_trace_class_kind"; +constexpr const char* kRaptorCommTraceBlockOrdinalAttr = "raptor.comm_trace_block_ordinal"; +constexpr const char* kRaptorCommTracePayloadAttr = "raptor.comm_trace_payload"; +constexpr const char* kRaptorCommTraceMessagesAttr = "raptor.comm_trace_messages"; +constexpr const char* kRaptorCommTracePrevOpAttr = "raptor.comm_trace_prev_op"; +constexpr const char* kRaptorCommTraceNextOpAttr = "raptor.comm_trace_next_op"; + +int64_t getMinimumChannelId(ArrayRef channelIds) { + assert(!channelIds.empty() && "expected at least one channel id"); + int64_t minChannelId = channelIds.front(); + for (int64_t channelId : channelIds.drop_front()) + if (channelId < minChannelId) + minChannelId = channelId; + return minChannelId; +} + +SmallVector getScalarSendChannelOrder(const MessageVector& messages) { + SmallVector order; + order.reserve(messages.size()); + for (size_t i = 0, e = messages.size(); i < e; ++i) + order.push_back(i); + + llvm::sort(order, [&](size_t lhs, size_t rhs) { + if (messages.channelIds[lhs] != messages.channelIds[rhs]) + return messages.channelIds[lhs] < messages.channelIds[rhs]; + if (messages.sourceCoreIds[lhs] != messages.sourceCoreIds[rhs]) + return messages.sourceCoreIds[lhs] < messages.sourceCoreIds[rhs]; + return messages.targetCoreIds[lhs] < messages.targetCoreIds[rhs]; + }); + return order; +} + +MessageVector reorderMessages(const 
MessageVector& messages, ArrayRef order) { + MessageVector reordered; + reordered.channelIds.reserve(messages.size()); + reordered.sourceCoreIds.reserve(messages.size()); + reordered.targetCoreIds.reserve(messages.size()); + for (size_t index : order) + reordered.append(messages.channelIds[index], messages.sourceCoreIds[index], messages.targetCoreIds[index]); + return reordered; +} + +MessageVector reorderScalarSendMessagesByChannel(const MessageVector& messages) { + return reorderMessages(messages, getScalarSendChannelOrder(messages)); +} + +ProjectedTransferDescriptor reorderProjectedDescriptorByMessageOrder(const ProjectedTransferDescriptor& descriptor, + ArrayRef order) { + ProjectedTransferDescriptor reordered = descriptor; + size_t payloadFragmentCount = static_cast(descriptor.layout.payloadFragmentCount); + reordered.fragmentOffsets.clear(); + reordered.fragmentOffsets.reserve(descriptor.fragmentOffsets.size()); + for (size_t messageIndex : order) { + size_t offset = messageIndex * payloadFragmentCount; + for (size_t fragmentIndex = 0; fragmentIndex < payloadFragmentCount; ++fragmentIndex) + reordered.fragmentOffsets.push_back(descriptor.fragmentOffsets[offset + fragmentIndex]); + } + reordered.fragmentOffsetsByDim.clear(); + return reordered; +} + + +Operation* getPayloadDefiningOpInClassBlock(Value payload, MaterializedClass& materializedClass) { + Operation* definingOp = payload.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != materializedClass.body) + return nullptr; + return definingOp; +} + +Operation* findScalarCommunicationInsertionPoint(MaterializedClass& materializedClass, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + Operation* terminator = materializedClass.body->getTerminator(); + bool afterLowerBound = lowerBound == nullptr; + + for (Operation& op : *materializedClass.body) { + if (&op == terminator) + break; + + if (!afterLowerBound) { + if (&op == lowerBound) + afterLowerBound = true; + continue; + } + + if 
(&op == lowerBound) + continue; + + auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); + if (existingMinChannel && existingMinChannel.getInt() > minChannelId) + return &op; + } + + return terminator; +} + +void setInsertionPointForScalarCommunication(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + state.rewriter.setInsertionPoint( + findScalarCommunicationInsertionPoint(materializedClass, minChannelId, lowerBound)); +} + +constexpr const char kRaptorCommOrderAttr[] = "raptor.comm_order"; + +int64_t computeBlockingCommunicationOrderKey(int32_t sourceCoreId, int32_t targetCoreId, int64_t channelId) { + int64_t lowCore = std::min(sourceCoreId, targetCoreId); + int64_t highCore = std::max(sourceCoreId, targetCoreId); + int64_t directionPhase = sourceCoreId <= targetCoreId ? 0 : 1; + return (((lowCore * 1000000LL + highCore) * 2LL + directionPhase) * 1000000000LL) + channelId; +} + +int64_t getMinimumBlockingCommunicationOrderKey(const MessageVector& messages) { + assert(!messages.empty() && "expected at least one message"); + int64_t best = computeBlockingCommunicationOrderKey( + messages.sourceCoreIds.front(), messages.targetCoreIds.front(), messages.channelIds.front()); + for (size_t index = 1, end = messages.size(); index < end; ++index) { + best = std::min(best, computeBlockingCommunicationOrderKey( + messages.sourceCoreIds[index], messages.targetCoreIds[index], messages.channelIds[index])); + } + return best; +} + +Operation* findScalarCommunicationInsertionPointByOrder(MaterializedClass& materializedClass, + int64_t orderKey, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + Operation* terminator = materializedClass.body->getTerminator(); + bool afterLowerBound = lowerBound == nullptr; + + for (Operation& op : *materializedClass.body) { + if (&op == terminator) + break; + + if (!afterLowerBound) { + if (&op == lowerBound) + afterLowerBound = true; + 
continue; + } + + if (&op == lowerBound) + continue; + + if (auto existingOrder = op.getAttrOfType(kRaptorCommOrderAttr)) { + if (existingOrder.getInt() > orderKey) + return &op; + continue; + } + + auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); + if (existingMinChannel && existingMinChannel.getInt() > minChannelId) + return &op; + } + + return terminator; +} + +void setInsertionPointForScalarCommunicationOrder(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t orderKey, + int64_t minChannelId, + Operation* lowerBound = nullptr) { + if (!pimMaterializeScalarFanoutGlobalOrder) { + setInsertionPointForScalarCommunication(state, materializedClass, minChannelId, lowerBound); + return; + } + + state.rewriter.setInsertionPoint( + findScalarCommunicationInsertionPointByOrder(materializedClass, orderKey, minChannelId, lowerBound)); +} + +void markScalarCommunication(Operation* op, int64_t minChannelId, StringRef materializer = StringRef()) { + if (!op) + return; + op->setAttr(kRaptorMinChannelIdAttr, + IntegerAttr::get(IndexType::get(op->getContext()), minChannelId)); + if (!materializer.empty()) + op->setAttr(kRaptorMaterializerAttr, StringAttr::get(op->getContext(), materializer)); +} + +void markScalarCommunicationOrder(Operation* op, int64_t orderKey) { + if (!op) + return; + op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(IndexType::get(op->getContext()), orderKey)); +} + +std::optional getOperationOrdinalInBlock(Operation* op) { + if (!op || !op->getBlock()) + return std::nullopt; + + int64_t ordinal = 0; + for (Operation& candidate : *op->getBlock()) { + if (&candidate == op) + return ordinal; + ++ordinal; + } + return std::nullopt; +} + +std::string formatOperationForTrace(Operation* op) { + if (!op) + return ""; + + std::string text; + llvm::raw_string_ostream os(text); + os << op->getName().getStringRef(); + if (auto ordinal = getOperationOrdinalInBlock(op)) + os << "@" << *ordinal; + return os.str(); +} + 
+std::string formatValueForTrace(Value value, Block* localBody) { + if (!value) + return ""; + + std::string text; + llvm::raw_string_ostream os(text); + if (auto arg = dyn_cast(value)) { + os << "block_arg#" << arg.getArgNumber(); + return os.str(); + } + + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) { + os << "external"; + return os.str(); + } + + os << definingOp->getName().getStringRef(); + if (definingOp->getBlock() == localBody) { + if (auto ordinal = getOperationOrdinalInBlock(definingOp)) + os << "@" << *ordinal; + } + else { + os << "@external-block"; + } + return os.str(); +} + +std::string formatClassForTrace(const MaterializedClass& materializedClass) { + std::string text; + llvm::raw_string_ostream os(text); + os << (materializedClass.isBatch ? "batch" : "scalar") << " class " << materializedClass.id << " cpus=["; + for (auto [index, cpu] : llvm::enumerate(materializedClass.cpus)) { + if (index) + os << ","; + os << cpu; + } + os << "]"; + return os.str(); +} + +std::string formatMessagesForTrace(const MessageVector& messages, unsigned maxMessages = 8) { + std::string text; + llvm::raw_string_ostream os(text); + os << "count=" << messages.size() << " ["; + unsigned limit = std::min(maxMessages, messages.size()); + for (unsigned index = 0; index < limit; ++index) { + if (index) + os << "; "; + os << "c" << messages.channelIds[index] << ":" << messages.sourceCoreIds[index] + << "->" << messages.targetCoreIds[index]; + } + if (messages.size() > limit) + os << "; ..."; + os << "]"; + return os.str(); +} + +void annotateCommunicationMaterialization(MaterializerState& state, + MaterializedClass& materializedClass, + Operation* op, + StringRef kind, + StringRef materializer, + StringRef phase, + std::optional minChannelId, + std::optional orderKey, + Value payload = Value(), + const MessageVector* messages = nullptr) { + if (!op) + return; + + MLIRContext* context = op->getContext(); + int64_t traceId = state.nextCommunicationTraceId++; 
+ auto indexType = IndexType::get(context); + op->setAttr(kRaptorCommTraceIdAttr, IntegerAttr::get(indexType, traceId)); + op->setAttr(kRaptorCommTraceKindAttr, StringAttr::get(context, kind)); + op->setAttr(kRaptorCommTracePhaseAttr, StringAttr::get(context, phase)); + op->setAttr(kRaptorCommTraceClassIdAttr, IntegerAttr::get(indexType, materializedClass.id)); + op->setAttr(kRaptorCommTraceClassKindAttr, + StringAttr::get(context, materializedClass.isBatch ? "batch" : "scalar")); + if (!materializer.empty()) + op->setAttr(kRaptorMaterializerAttr, StringAttr::get(context, materializer)); + if (minChannelId) + op->setAttr(kRaptorMinChannelIdAttr, IntegerAttr::get(indexType, *minChannelId)); + if (orderKey) + op->setAttr(kRaptorCommOrderAttr, IntegerAttr::get(indexType, *orderKey)); + if (auto ordinal = getOperationOrdinalInBlock(op)) + op->setAttr(kRaptorCommTraceBlockOrdinalAttr, IntegerAttr::get(indexType, *ordinal)); + op->setAttr(kRaptorCommTracePayloadAttr, + StringAttr::get(context, formatValueForTrace(payload, materializedClass.body))); + if (messages) + op->setAttr(kRaptorCommTraceMessagesAttr, StringAttr::get(context, formatMessagesForTrace(*messages))); + + Operation* prev = op->getPrevNode(); + Operation* next = op->getNextNode(); + op->setAttr(kRaptorCommTracePrevOpAttr, StringAttr::get(context, formatOperationForTrace(prev))); + op->setAttr(kRaptorCommTraceNextOpAttr, StringAttr::get(context, formatOperationForTrace(next))); + + if (!pimTraceCommunicationMaterialization) + return; + + llvm::errs() << "[raptor:comm-materializer] #" << traceId << " " << kind + << " via " << materializer << " phase=" << phase << " " + << formatClassForTrace(materializedClass); + if (minChannelId) + llvm::errs() << " min_channel=" << *minChannelId; + if (orderKey) + llvm::errs() << " order=" << *orderKey; + if (auto ordinal = getOperationOrdinalInBlock(op)) + llvm::errs() << " block_ordinal=" << *ordinal; + llvm::errs() << " payload=" << formatValueForTrace(payload, 
materializedClass.body); + if (messages) + llvm::errs() << " messages=" << formatMessagesForTrace(*messages); + llvm::errs() << " prev=" << formatOperationForTrace(prev) + << " next=" << formatOperationForTrace(next) << "\n"; +} + +void setInsertionPointForEarlyCommunication(MaterializerState& state, MaterializedClass& materializedClass) { + auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); + if (lateIt != state.firstLateCommunicationOps.end() && lateIt->second && lateIt->second->getBlock()) { + state.rewriter.setInsertionPoint(lateIt->second); + return; + } + + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); +} + +void setInsertionPointForLateCommunication(MaterializerState& state, MaterializedClass& materializedClass) { + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); +} + + +Operation* findLateScalarCommunicationInsertionPoint(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t minChannelId) { + Operation* terminator = materializedClass.body->getTerminator(); + auto lateIt = state.firstLateCommunicationOps.find(materializedClass.id); + Operation* firstLate = lateIt == state.firstLateCommunicationOps.end() ? 
nullptr : lateIt->second; + if (!firstLate || firstLate->getBlock() != materializedClass.body) + return terminator; + + bool inLateRegion = false; + for (Operation& op : *materializedClass.body) { + if (&op == terminator) + break; + + if (!inLateRegion) { + if (&op == firstLate) + inLateRegion = true; + else + continue; + } + + auto existingMinChannel = op.getAttrOfType(kRaptorMinChannelIdAttr); + if (existingMinChannel && existingMinChannel.getInt() > minChannelId) + return &op; + } + + return terminator; +} + +void setInsertionPointForLateScalarCommunication(MaterializerState& state, + MaterializedClass& materializedClass, + int64_t minChannelId) { + state.rewriter.setInsertionPoint( + findLateScalarCommunicationInsertionPoint(state, materializedClass, minChannelId)); +} + +void rememberLateCommunicationOp(MaterializerState& state, MaterializedClass& materializedClass, Operation* op) { + if (!op || op->getBlock() != materializedClass.body) + return; + + Operation*& firstLate = state.firstLateCommunicationOps[materializedClass.id]; + if (!firstLate || firstLate->getBlock() != materializedClass.body || op->isBeforeInBlock(firstLate)) + firstLate = op; +} + + + +constexpr const char kMinCommunicationChannelIdAttr[] = "raptor.min_channel_id"; + +std::optional getConstantIndexValue(Value value) { + APInt constant; + if (matchPattern(value, m_ConstantInt(&constant))) + return constant.getSExtValue(); + return std::nullopt; +} + +std::optional getCommunicationChannelId(Operation& op) { + if (auto attr = op.getAttrOfType(kMinCommunicationChannelIdAttr)) + return attr.getInt(); + + if (auto send = dyn_cast(&op)) + return getConstantIndexValue(send.getChannelId()); + if (auto receive = dyn_cast(&op)) + return getConstantIndexValue(receive.getChannelId()); + + return std::nullopt; +} + +int64_t getMinimumCommunicationChannelId(const MessageVector& messages) { + assert(!messages.empty() && "expected at least one message"); + return 
*std::min_element(messages.channelIds.begin(), messages.channelIds.end()); +} + +void markCommunicationChannelId(Operation* op, int64_t channelId) { + if (!op) + return; + op->setAttr(kMinCommunicationChannelIdAttr, + IntegerAttr::get(IntegerType::get(op->getContext(), 64), channelId)); +} + +Operation* getSameBlockDefiningOp(Value value, Block* block) { + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != block) + return nullptr; + return definingOp; +} + + +bool valueDependsOnChannelReceive(Value root) { + SmallVector worklist; + DenseSet visitedValues; + DenseSet visitedOps; + worklist.push_back(root); + + auto visitOperand = [&](Value value) { + if (value && visitedValues.insert(value).second) + worklist.push_back(value); + }; + + while (!worklist.empty()) { + Value value = worklist.pop_back_val(); + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || !visitedOps.insert(definingOp).second) + continue; + + if (isa(definingOp)) + return true; + + for (Value operand : definingOp->getOperands()) + visitOperand(operand); + + for (Region& region : definingOp->getRegions()) { + for (Block& block : region) { + for (Operation& nested : block) { + for (Value operand : nested.getOperands()) + visitOperand(operand); + } + } + } + } + + return false; +} + +bool shouldDelayScalarSendUntilAfterReceives(Value payload, int32_t sourceCoreId, int32_t targetCoreId) { + if (sourceCoreId <= targetCoreId) + return false; + return valueDependsOnChannelReceive(payload); +} + +void partitionScalarMessagesByReceiveDependency(Value payload, + const MessageVector& messages, + MessageVector& earlyMessages, + MessageVector& lateMessages) { + for (size_t i = 0, e = messages.size(); i < e; ++i) { + MessageVector& bucket = shouldDelayScalarSendUntilAfterReceives( + payload, messages.sourceCoreIds[i], messages.targetCoreIds[i]) + ? 
lateMessages + : earlyMessages; + bucket.append(messages.channelIds[i], messages.sourceCoreIds[i], messages.targetCoreIds[i]); + } +} + +void setInsertionPointForScalarSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + int64_t minChannelId, + bool late) { + if (late) { + setInsertionPointForLateScalarCommunication(state, sourceClass, minChannelId); + return; + } + + setInsertionPointForScalarCommunication( + state, sourceClass, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); +} + + +void appendScalarSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); + + bool late = shouldDelayScalarSendUntilAfterReceives(payload, sourceCoreId, targetCoreId); + int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); + if (pimMaterializeScalarFanoutGlobalOrder) + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, orderKey, channelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + else + setInsertionPointForScalarSend(state, sourceClass, payload, channelId, late); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); + auto send = SpatChannelSendOp::create( + state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); + markScalarCommunication(send.getOperation(), channelId, "appendScalarSend"); + markScalarCommunicationOrder(send.getOperation(), orderKey); + MessageVector traceMessages; + traceMessages.append(channelId, sourceCoreId, targetCoreId); + 
annotateCommunicationMaterialization(state, + sourceClass, + send.getOperation(), + "send", + "appendScalarSend", + late ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), + channelId, + orderKey, + payload, + &traceMessages); + if (late && !pimMaterializeScalarFanoutGlobalOrder) + rememberLateCommunicationOp(state, sourceClass, send.getOperation()); +} + +LogicalResult emitScalarSendLoopAtInsertionPoint(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + int64_t minChannelId, + int64_t orderKey, + Location loc) { + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); + + auto sendLoop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + return success(); + }); + if (failed(sendLoop)) + return failure(); + markScalarCommunication(sendLoop->loop.getOperation(), minChannelId, "appendScalarSendLoop"); + markScalarCommunicationOrder(sendLoop->loop.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + sourceClass, + sendLoop->loop.getOperation(), + "send-loop", + "appendScalarSendLoop", + "loop", + minChannelId, + orderKey, + payload, + &messages); + return success(); +} + +LogicalResult appendScalarSendLoop(MaterializerState& state, + 
MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); + assert(messages.size() > 1 && "send loop is only useful for multiple sends"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + + MessageVector orderedMessages = reorderScalarSendMessagesByChannel(messages); + if (pimMaterializeScalarFanoutGlobalOrder) { + for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) + appendScalarSend(state, + sourceClass, + payload, + orderedMessages.channelIds[index], + orderedMessages.sourceCoreIds[index], + orderedMessages.targetCoreIds[index], + loc); + return success(); + } + + int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + return emitScalarSendLoopAtInsertionPoint(state, sourceClass, payload, orderedMessages, minChannelId, orderKey, loc); +} + + +FailureOr buildProjectedPackedPayload(MaterializerState& state, + MaterializedClass& targetClass, + Value fullPayload, + const ProjectedTransferDescriptor& descriptor, + Value messageIndex, + Location loc) { + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + if (descriptor.layout.payloadFragmentCount == 1) + return targetClass.op->emitError("projected packed payload builder expects a packed payload"); + + Value init = tensor::EmptyOp::create( + state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) + .getResult(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 
descriptor.layout.payloadFragmentCount); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value payloadFragmentCount = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); + FailureOr localMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); + if (failed(localMessageIndex)) + return failure(); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localMessageIndex, payloadFragmentCount).getResult(); + Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); + + FailureOr> fragmentOffsets = + buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc); + if (failed(fragmentOffsets)) + return failure(); + FailureOr fragment = createStaticExtractSliceInClass( + state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + if (failed(fragment)) + return failure(); + + FailureOr packedOffset = scaleIndexByDim0SizeInClass( + state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); + if (failed(packedOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +FailureOr buildProjectedPayloadForMessage(MaterializerState& state, + MaterializedClass& targetClass, + Value fullPayload, + const ProjectedTransferDescriptor& descriptor, + Value messageIndex, + Location loc) { + if (failed(verifyProjectedTransferDescriptor(targetClass.op, 
descriptor))) + return failure(); + + if (descriptor.layout.payloadFragmentCount == 1) { + FailureOr> fragmentOffsets = + buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, messageIndex, loc); + if (failed(fragmentOffsets)) + return failure(); + return createStaticExtractSliceInClass( + state, targetClass, loc, fullPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + } + + return buildProjectedPackedPayload(state, targetClass, fullPayload, descriptor, messageIndex, loc); +} + +LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ProjectedTransferDescriptor& descriptor, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + + SmallVector messageOrder = getScalarSendChannelOrder(messages); + MessageVector orderedMessages = reorderMessages(messages, messageOrder); + ProjectedTransferDescriptor orderedDescriptor = reorderProjectedDescriptorByMessageOrder(descriptor, messageOrder); + if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, orderedDescriptor))) + return failure(); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, orderedDescriptor, orderedMessages))) + return failure(); + + int64_t minChannelId = getMinimumChannelId(orderedMessages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(orderedMessages); + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, orderKey, minChannelId, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + + if (orderedMessages.size() == 1 || pimMaterializeScalarFanoutGlobalOrder) { + for (size_t index = 0, end = orderedMessages.size(); index < end; ++index) { + int64_t channel = orderedMessages.channelIds[index]; + int32_t sourceCore = orderedMessages.sourceCoreIds[index]; + int32_t targetCore = 
orderedMessages.targetCoreIds[index]; + int64_t localOrderKey = computeBlockingCommunicationOrderKey(sourceCore, targetCore, channel); + setInsertionPointForScalarCommunicationOrder( + state, sourceClass, localOrderKey, channel, getPayloadDefiningOpInClassBlock(payload, sourceClass)); + + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channel); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCore); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCore); + Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(index)); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, messageIndex, loc); + if (failed(sendPayload)) + return failure(); + auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); + markScalarCommunication(send.getOperation(), channel, "appendProjectedScalarSendLoop.single"); + markScalarCommunicationOrder(send.getOperation(), localOrderKey); + MessageVector traceMessages; + traceMessages.append(channel, sourceCore, targetCore); + annotateCommunicationMaterialization(state, + sourceClass, + send.getOperation(), + "send", + "appendProjectedScalarSendLoop.single", + "projected-single", + channel, + localOrderKey, + *sendPayload, + &traceMessages); + } + return success(); + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(orderedMessages.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); + + auto projectedSendLoop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + Value channelId = 
createIndexedChannelId(state, sourceClass.op, orderedMessages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, orderedMessages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, orderedMessages, index, loc); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass, payload, orderedDescriptor, index, loc); + if (failed(sendPayload)) + return failure(); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); + return success(); + }); + if (failed(projectedSendLoop)) + return failure(); + markScalarCommunication(projectedSendLoop->loop.getOperation(), minChannelId, "appendProjectedScalarSendLoop.loop"); + markScalarCommunicationOrder(projectedSendLoop->loop.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + sourceClass, + projectedSendLoop->loop.getOperation(), + "send-loop", + "appendProjectedScalarSendLoop.loop", + "projected-loop", + minChannelId, + orderKey, + payload, + &orderedMessages); + return success(); +} + + +LogicalResult appendSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one send"); + + if (sourceClass.isBatch) { + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); + auto send = SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + int64_t minChannelId = getMinimumChannelId(messages.channelIds); + int64_t orderKey 
= getMinimumBlockingCommunicationOrderKey(messages); + markScalarCommunication(send.getOperation(), minChannelId, "appendSend.batch"); + markScalarCommunicationOrder(send.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + sourceClass, + send.getOperation(), + "send", + "appendSend.batch", + "batch-lane-indexed", + minChannelId, + orderKey, + payload, + &messages); + return success(); + } + + if (messages.size() == 1) { + appendScalarSend(state, + sourceClass, + payload, + messages.channelIds.front(), + messages.sourceCoreIds.front(), + messages.targetCoreIds.front(), + loc); + return success(); + } + + return appendScalarSendLoop(state, sourceClass, payload, messages, loc); +} + +Value appendScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + Type type, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc, + bool lateReceive = false) { + assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); + + int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); + if (lateReceive) + setInsertionPointForLateScalarCommunication(state, targetClass, channelId); + else + setInsertionPointForScalarCommunicationOrder(state, targetClass, orderKey, channelId); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); + auto receive = SpatChannelReceiveOp::create( + state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); + markScalarCommunication(receive.getOperation(), channelId, + lateReceive ? 
"appendScalarReceive.late" : "appendScalarReceive"); + markScalarCommunicationOrder(receive.getOperation(), orderKey); + MessageVector traceMessages; + traceMessages.append(channelId, sourceCoreId, targetCoreId); + annotateCommunicationMaterialization(state, + targetClass, + receive.getOperation(), + "receive", + lateReceive ? "appendScalarReceive.late" : "appendScalarReceive", + lateReceive ? "late" : (pimMaterializeScalarFanoutGlobalOrder ? "global" : "early"), + channelId, + orderKey, + Value(), + &traceMessages); + return receive.getOutput(); +} + + +Value appendReceive( + MaterializerState& state, + MaterializedClass& targetClass, + Type type, + const MessageVector& messages, + Location loc, + bool lateReceive = false) { + assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one receive"); + + if (lateReceive) + setInsertionPointForLateScalarCommunication(state, targetClass, getMinimumChannelId(messages.channelIds)); + else + setInsertionPointForEarlyCommunication(state, targetClass); + + if (targetClass.isBatch) { + Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); + auto receive = SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId); + int64_t minChannelId = getMinimumChannelId(messages.channelIds); + int64_t orderKey = getMinimumBlockingCommunicationOrderKey(messages); + markScalarCommunication(receive.getOperation(), minChannelId, "appendReceive.batch"); + markScalarCommunicationOrder(receive.getOperation(), orderKey); + annotateCommunicationMaterialization(state, + targetClass, + receive.getOperation(), + "receive", + "appendReceive.batch", + lateReceive ? 
"late-batch" : "early-batch", + minChannelId, + orderKey, + Value(), + &messages); + return receive.getOutput(); + } + + assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); + return appendScalarReceive(state, + targetClass, + type, + messages.channelIds.front(), + messages.sourceCoreIds.front(), + messages.targetCoreIds.front(), + loc, + lateReceive); +} + +Value appendScalarReceiveAtCurrentInsertionPoint(MaterializerState& state, + MaterializedClass& targetClass, + Type type, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!targetClass.isBatch && "demand scalar receive expects a scalar target class"); + + int64_t orderKey = computeBlockingCommunicationOrderKey(sourceCoreId, targetCoreId, channelId); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); + auto receive = SpatChannelReceiveOp::create( + state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue); + markScalarCommunication(receive.getOperation(), channelId, "appendScalarReceive.demand"); + markScalarCommunicationOrder(receive.getOperation(), orderKey); + MessageVector traceMessages; + traceMessages.append(channelId, sourceCoreId, targetCoreId); + annotateCommunicationMaterialization(state, + targetClass, + receive.getOperation(), + "receive", + "appendScalarReceive.demand", + "demand", + channelId, + orderKey, + Value(), + &traceMessages); + return receive.getOutput(); +} + +std::optional lookupPendingScalarReceiveIndex(MaterializerState& state, + ProducerKey key, + ClassId targetClassId) { + auto keyIt = state.pendingScalarReceiveLookup.find(key); + if (keyIt == state.pendingScalarReceiveLookup.end()) + return std::nullopt; + + auto 
classIt = keyIt->second.find(targetClassId); + if (classIt == keyIt->second.end()) + return std::nullopt; + return classIt->second; +} + +void recordPendingScalarReceive(MaterializerState& state, + ClassId targetClassId, + ArrayRef keys, + Type receiveType, + const MessageVector& messages, + Location loc) { + if (keys.empty()) + return; + + if (lookupPendingScalarReceiveIndex(state, keys.front(), targetClassId)) + return; + + size_t recordIndex = state.pendingScalarReceives.size(); + state.pendingScalarReceives.emplace_back(keys, targetClassId, receiveType, messages, loc); + + for (ProducerKey key : keys) + state.pendingScalarReceiveLookup[key][targetClassId] = recordIndex; +} + +FailureOr materializePendingScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + size_t recordIndex, + Location loc) { + if (recordIndex >= state.pendingScalarReceives.size()) + return targetClass.op->emitError("pending scalar receive index is out of bounds"); + + PendingScalarReceiveRecord& record = state.pendingScalarReceives[recordIndex]; + if (record.targetClassId != targetClass.id) + return targetClass.op->emitError("pending scalar receive target class mismatch"); + + if (record.materialized) + return record.value; + + if (targetClass.isBatch) + return targetClass.op->emitError("pending scalar receive cannot materialize into a batch class"); + if (record.messages.size() != 1) + return targetClass.op->emitError("pending scalar receive expected exactly one scalar message"); + + Location receiveLoc = loc; + Value received = appendScalarReceiveAtCurrentInsertionPoint(state, + targetClass, + record.receiveType, + record.messages.channelIds.front(), + record.messages.sourceCoreIds.front(), + record.messages.targetCoreIds.front(), + receiveLoc); + record.materialized = true; + record.value = received; + + for (ProducerKey key : record.keys) + state.availableValues.record(key, targetClass.id, received); + + return received; +} + + +LogicalResult 
materializePendingScalarReceivesForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey wholeBatchKey, + Location loc) { + if (targetClass.isBatch || !isWholeBatchProducerKey(wholeBatchKey)) + return success(); + + SmallVector pendingIndices; + for (auto [recordIndex, record] : llvm::enumerate(state.pendingScalarReceives)) { + if (record.targetClassId != targetClass.id || record.materialized) + continue; + + bool contributesToWholeBatch = llvm::any_of(record.keys, [&](ProducerKey fragmentKey) { + return containsProducerKey(wholeBatchKey, fragmentKey); + }); + if (contributesToWholeBatch) + pendingIndices.push_back(recordIndex); + } + + if (pendingIndices.empty()) + return success(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + for (size_t recordIndex : pendingIndices) { + FailureOr received = materializePendingScalarReceive(state, targetClass, recordIndex, loc); + if (failed(received)) + return failure(); + } + + return success(); +} + +LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Type fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { + if (!sourceClass.isBatch) + return sourceClass.op->emitError("lazy packed scalar receives expect a batch source class"); + + if (targetClass.isBatch) + return targetClass.op->emitError("lazy packed scalar receives expect a scalar target class"); + + if (keys.empty()) + return sourceClass.op->emitError("lazy packed scalar receive expects at least one producer key"); + + if (keys.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError("lazy packed scalar receive expects one producer key per source lane"); + + MessageVector messages; + messages.append(channelIds, sourceCoreIds, targetCoreIds); + if (failed(messages.verify(targetClass.op))) + return failure(); + + if (keys.size() != messages.size()) + 
return targetClass.op->emitError("lazy packed scalar receive metadata is inconsistent"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("lazy packed scalar receive expects a static ranked fragment type"); + + if (failed(verifyPackableFragmentType( + targetClass.op, fragmentType, keys.size(), "cannot create lazy packed scalar receive type"))) + return failure(); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return sourceClass.op->emitError("lazy packed scalar receive expects one producer result"); + + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("lazy packed scalar receive expects one lane per producer key"); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = targetClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredReceive; + packedRun.fragmentType = rankedFragmentType; + + packedRun.messages = std::move(messages); + + PackedScalarRunSlot slot; + llvm::append_range(slot.keys, keys); + packedRun.slots.push_back(std::move(slot)); + + if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) + return failure(); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +struct ScalarSourceReceivePlan { + ClassId targetClass = 0; + MessageVector messages; + Type receiveType; + Operation* projectedExtractOp = nullptr; + ProjectedFragmentLayout projectedLayout; + std::optional projectedDescriptor; +}; + +struct ProjectedScalarSendGroup { + MessageVector messages; + ProjectedTransferDescriptor descriptor; +}; + +struct ScalarSourceFanoutPlan { + SmallVector receivePlans; + std::optional ordinaryMessages; + 
SmallVector projectedSendGroups; +}; + +bool hasSameProjectedSendCompatibility(const ProjectedTransferDescriptor& lhs, const ProjectedTransferDescriptor& rhs) { + return lhs.layout.fragmentType == rhs.layout.fragmentType && lhs.layout.fragmentShape == rhs.layout.fragmentShape + && lhs.layout.fragmentsPerLogicalSlot == rhs.layout.fragmentsPerLogicalSlot + && lhs.layout.payloadFragmentCount == rhs.layout.payloadFragmentCount + && lhs.layout.loopLowerBounds == rhs.layout.loopLowerBounds && lhs.layout.loopSteps == rhs.layout.loopSteps + && lhs.layout.loopTripCounts == rhs.layout.loopTripCounts && lhs.payloadType == rhs.payloadType; +} + +SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { + SmallVector destinations; + + for (ProducerKey key : keys) + for (ClassId destinationClass : getDestinationClasses(state, key)) + destinations.push_back(destinationClass); + + llvm::sort(destinations); + destinations.erase(std::unique(destinations.begin(), destinations.end()), destinations.end()); + return destinations; +} + +FailureOr buildScalarSourceFanoutPlan(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + ArrayRef destinationClasses, + Value payload) { + assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); + + auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id"); + if (failed(sourceCpu)) + return failure(); + + ScalarSourceFanoutPlan fanoutPlan; + fanoutPlan.receivePlans.reserve(destinationClasses.size()); + + const auto getProjectedDescriptor = + [&](ClassId destinationClass) -> FailureOr> { + MaterializedClass& targetClass = state.classes[destinationClass]; + if (!targetClass.isBatch) { + bool hasAnyProjectedDescriptor = llvm::any_of(keys, [&](ProducerKey key) { + auto producerIt = state.projectedTransfers.find(key); + return producerIt != state.projectedTransfers.end() && producerIt->second.count(destinationClass) != 0; + 
}); + + std::optional descriptor = collectScalarTargetProjectedDescriptor( + state, targetClass, keys, /*requirePackedRunOffsetCountMatch=*/keys.size() > 1); + if (hasAnyProjectedDescriptor && !descriptor) + return targetClass.op->emitError("incomplete scalar projected transfer descriptor for local run"); + return descriptor; + } + + if (keys.size() != 1) + return std::optional {}; + + auto producerIt = state.projectedTransfers.find(keys.front()); + if (producerIt == state.projectedTransfers.end()) + return std::optional {}; + + auto descriptorIt = producerIt->second.find(destinationClass); + if (descriptorIt == producerIt->second.end()) + return std::optional {}; + + const ProjectedTransferDescriptor& descriptor = descriptorIt->second; + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + if (descriptor.fragmentOffsets.size() + != targetClass.cpus.size() * static_cast(descriptor.layout.payloadFragmentCount)) + return targetClass.op->emitError("inconsistent batch projected transfer descriptor"); + + return std::optional {descriptor}; + }; + + for (ClassId destinationClass : destinationClasses) { + if (destinationClass == sourceClass.id) + continue; + + MaterializedClass& targetClass = state.classes[destinationClass]; + + ScalarSourceReceivePlan receivePlan; + receivePlan.targetClass = destinationClass; + receivePlan.receiveType = payload.getType(); + + auto appendMessage = [&](CpuId targetCpu) -> LogicalResult { + auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id"); + if (failed(checkedTargetCpu)) + return failure(); + int64_t channelId = state.nextChannelId++; + + receivePlan.messages.append(channelId, *sourceCpu, *checkedTargetCpu); + return success(); + }; + + if (!targetClass.isBatch) { + if (failed(appendMessage(targetClass.cpus.front()))) + return failure(); + } + else { + for (CpuId targetCpu : targetClass.cpus) + if (failed(appendMessage(targetCpu))) + return failure(); + 
} + + FailureOr> descriptor = getProjectedDescriptor(destinationClass); + if (failed(descriptor)) + return failure(); + + if (*descriptor) { + const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; + + if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) { + if (!fanoutPlan.ordinaryMessages) + fanoutPlan.ordinaryMessages = MessageVector {}; + fanoutPlan.ordinaryMessages->append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + fanoutPlan.receivePlans.push_back(std::move(receivePlan)); + continue; + } + + receivePlan.receiveType = projectedDescriptor.payloadType; + receivePlan.projectedExtractOp = projectedDescriptor.extractOp; + receivePlan.projectedLayout = projectedDescriptor.layout; + receivePlan.projectedDescriptor = projectedDescriptor; + + auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { + return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); + }); + if (groupIt == fanoutPlan.projectedSendGroups.end()) { + ProjectedScalarSendGroup group; + group.descriptor.layout = projectedDescriptor.layout; + group.descriptor.payloadType = projectedDescriptor.payloadType; + fanoutPlan.projectedSendGroups.push_back(std::move(group)); + groupIt = std::prev(fanoutPlan.projectedSendGroups.end()); + } + + groupIt->messages.append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + llvm::append_range(groupIt->descriptor.fragmentOffsets, projectedDescriptor.fragmentOffsets); + } + else { + if (!fanoutPlan.ordinaryMessages) + fanoutPlan.ordinaryMessages = MessageVector {}; + fanoutPlan.ordinaryMessages->append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + } + + fanoutPlan.receivePlans.push_back(std::move(receivePlan)); + } + + for (ProjectedScalarSendGroup& group : 
fanoutPlan.projectedSendGroups) { + if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, group.descriptor))) + return failure(); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, group.descriptor, group.messages))) + return failure(); + } + + return fanoutPlan; +} + +LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ScalarSourceFanoutPlan& plan, + Location loc) { + if (plan.ordinaryMessages && failed(appendSend(state, sourceClass, payload, *plan.ordinaryMessages, loc))) + return failure(); + + for (const ProjectedScalarSendGroup& group : plan.projectedSendGroups) + if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, group.descriptor, group.messages, loc))) + return failure(); + + return success(); +} + + +struct GloballyOrderedScalarFanoutEvent { + size_t receivePlanIndex = 0; + int64_t minChannelId = 0; + int64_t orderKey = 0; + int32_t minSourceCoreId = 0; + int32_t minTargetCoreId = 0; +}; + +GloballyOrderedScalarFanoutEvent makeGloballyOrderedScalarFanoutEvent(size_t receivePlanIndex, + const ScalarSourceReceivePlan& plan) { + assert(!plan.messages.empty() && "expected a communication event with at least one message"); + GloballyOrderedScalarFanoutEvent event; + event.receivePlanIndex = receivePlanIndex; + event.minChannelId = plan.messages.channelIds.front(); + event.orderKey = getMinimumBlockingCommunicationOrderKey(plan.messages); + event.minSourceCoreId = plan.messages.sourceCoreIds.front(); + event.minTargetCoreId = plan.messages.targetCoreIds.front(); + + for (size_t index = 1, end = plan.messages.size(); index < end; ++index) { + event.minChannelId = std::min(event.minChannelId, plan.messages.channelIds[index]); + event.minSourceCoreId = std::min(event.minSourceCoreId, plan.messages.sourceCoreIds[index]); + event.minTargetCoreId = std::min(event.minTargetCoreId, plan.messages.targetCoreIds[index]); + } + + return event; +} + +SmallVector 
+collectGloballyOrderedScalarFanoutEvents(const ScalarSourceFanoutPlan& plan) { + SmallVector events; + events.reserve(plan.receivePlans.size()); + + for (auto [index, receivePlan] : llvm::enumerate(plan.receivePlans)) + if (!receivePlan.messages.empty()) + events.push_back(makeGloballyOrderedScalarFanoutEvent(index, receivePlan)); + + llvm::sort(events, [](const GloballyOrderedScalarFanoutEvent& lhs, + const GloballyOrderedScalarFanoutEvent& rhs) { + if (lhs.orderKey != rhs.orderKey) + return lhs.orderKey < rhs.orderKey; + if (lhs.minChannelId != rhs.minChannelId) + return lhs.minChannelId < rhs.minChannelId; + if (lhs.minSourceCoreId != rhs.minSourceCoreId) + return lhs.minSourceCoreId < rhs.minSourceCoreId; + if (lhs.minTargetCoreId != rhs.minTargetCoreId) + return lhs.minTargetCoreId < rhs.minTargetCoreId; + return lhs.receivePlanIndex < rhs.receivePlanIndex; + }); + + return events; +} + +LogicalResult emitGloballyOrderedScalarFanoutSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ScalarSourceReceivePlan& plan, + Location loc) { + if (plan.projectedDescriptor) + return appendProjectedScalarSendLoop(state, sourceClass, payload, *plan.projectedDescriptor, plan.messages, loc); + + return appendSend(state, sourceClass, payload, plan.messages, loc); +} + +bool isMaterializedBlockingCommunication(Operation& op) { + return isa(&op) || op.hasAttr(kRaptorMinChannelIdAttr) + || op.hasAttr(kRaptorCommOrderAttr); +} + +bool payloadIsAvailableOnlyAfterPriorCommunication(Value payload, MaterializedClass& sourceClass) { + Operation* lowerBound = getPayloadDefiningOpInClassBlock(payload, sourceClass); + if (!lowerBound) + return false; + + bool sawPriorCommunication = false; + Operation* terminator = sourceClass.body->getTerminator(); + for (Operation& op : *sourceClass.body) { + if (&op == terminator) + break; + + if (&op == lowerBound) + return sawPriorCommunication || isMaterializedBlockingCommunication(op); + + if 
(isMaterializedBlockingCommunication(op)) + sawPriorCommunication = true; + } + + return sawPriorCommunication; +} + +bool shouldPlaceMatchingScalarFanoutReceiveLate(MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages) { + if (payloadIsAvailableOnlyAfterPriorCommunication(payload, sourceClass)) + return true; + + for (size_t index = 0, end = messages.size(); index < end; ++index) + if (shouldDelayScalarSendUntilAfterReceives( + payload, messages.sourceCoreIds[index], messages.targetCoreIds[index])) + return true; + return false; +} + +LogicalResult emitGloballyOrderedScalarSourceFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + const ScalarSourceFanoutPlan& plan, + Location loc) { + SmallVector events = collectGloballyOrderedScalarFanoutEvents(plan); + + for (const GloballyOrderedScalarFanoutEvent& event : events) { + const ScalarSourceReceivePlan& planEntry = plan.receivePlans[event.receivePlanIndex]; + MaterializedClass& targetClass = state.classes[planEntry.targetClass]; + + if (failed(emitGloballyOrderedScalarFanoutSend(state, sourceClass, payload, planEntry, loc))) + return failure(); + + if (!targetClass.isBatch && !planEntry.projectedExtractOp) { + recordPendingScalarReceive(state, targetClass.id, keys, planEntry.receiveType, planEntry.messages, loc); + continue; + } + + bool lateReceive = shouldPlaceMatchingScalarFanoutReceiveLate(sourceClass, payload, planEntry.messages); + Value received = appendReceive(state, targetClass, planEntry.receiveType, planEntry.messages, loc, lateReceive); + + if (planEntry.projectedExtractOp) { + state.projectedExtractReplacements[planEntry.projectedExtractOp][planEntry.targetClass] = + ProjectedExtractReplacement {received, planEntry.projectedLayout}; + continue; + } + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + } + + return success(); +} + +LogicalResult emitScalarSourceCommunication( + 
MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { + assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); + + for (ProducerKey key : keys) + state.availableValues.record(key, sourceClass.id, payload); + + SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); + auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); + if (failed(fanoutPlan)) + return failure(); + if (pimMaterializeScalarFanoutGlobalOrder) + return emitGloballyOrderedScalarSourceFanout(state, sourceClass, keys, payload, *fanoutPlan, loc); + + if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) + return failure(); + + for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { + MaterializedClass& targetClass = state.classes[plan.targetClass]; + + Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); + + if (plan.projectedExtractOp) { + state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = + ProjectedExtractReplacement {received, plan.projectedLayout}; + continue; + } + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + } + + return success(); +} + +FailureOr emitOrderedBatchToBatchCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(sourceClass.isBatch && targetClass.isBatch && "ordered batch communication expects two batch classes"); + if (failed(messages.verify(sourceClass.op))) + return failure(); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape()) + return sourceClass.op->emitError("ordered batch communication expects a static ranked tensor payload"); + + auto makeEmpty = [&](MaterializedClass& 
materializedClass) -> Value { + return tensor::EmptyOp::create( + state.rewriter, loc, payloadType.getShape(), payloadType.getElementType()) + .getResult(); + }; + + setInsertionPointForEarlyCommunication(state, sourceClass); + Value sendChannelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); + Value sendSourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); + Value sendTargetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); + Value sendEarlyCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sle, + sendSourceCoreId, + sendTargetCoreId) + .getResult(); + auto earlySendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendEarlyCond, /*withElseRegion=*/false); + state.rewriter.setInsertionPoint(earlySendIf.thenBlock()->getTerminator()); + auto earlySend = SpatChannelSendOp::create( + state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); + markScalarCommunication( + earlySend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlySend"); + + setInsertionPointForLateCommunication(state, sourceClass); + Value sendLateCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sgt, + sendSourceCoreId, + sendTargetCoreId) + .getResult(); + auto lateSendIf = scf::IfOp::create(state.rewriter, loc, TypeRange {}, sendLateCond, /*withElseRegion=*/false); + rememberLateCommunicationOp(state, sourceClass, lateSendIf.getOperation()); + state.rewriter.setInsertionPoint(lateSendIf.thenBlock()->getTerminator()); + auto lateSend = SpatChannelSendOp::create( + state.rewriter, loc, sendChannelId, sendSourceCoreId, sendTargetCoreId, payload); + markScalarCommunication( + lateSend.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateSend"); + + setInsertionPointForEarlyCommunication(state, targetClass); + Value 
recvChannelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); + Value recvSourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); + Value recvTargetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); + Value recvEarlyCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sle, + recvSourceCoreId, + recvTargetCoreId) + .getResult(); + auto earlyReceiveIf = scf::IfOp::create( + state.rewriter, loc, TypeRange {payload.getType()}, recvEarlyCond, /*withElseRegion=*/true); + Operation* earlyThenYield = earlyReceiveIf.thenBlock()->getTerminator(); + state.rewriter.setInsertionPoint(earlyThenYield); + auto earlyReceive = SpatChannelReceiveOp::create( + state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); + markScalarCommunication( + earlyReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.earlyReceive"); + Value earlyReceived = earlyReceive.getOutput(); + state.rewriter.modifyOpInPlace(earlyThenYield, [&] { earlyThenYield->setOperands(ValueRange {earlyReceived}); }); + Operation* earlyElseYield = earlyReceiveIf.elseBlock()->getTerminator(); + state.rewriter.setInsertionPoint(earlyElseYield); + Value empty = makeEmpty(targetClass); + state.rewriter.modifyOpInPlace(earlyElseYield, [&] { earlyElseYield->setOperands(ValueRange {empty}); }); + + setInsertionPointForLateCommunication(state, targetClass); + Value recvLateCond = arith::CmpIOp::create( + state.rewriter, + loc, + arith::CmpIPredicate::sgt, + recvSourceCoreId, + recvTargetCoreId) + .getResult(); + auto lateReceiveIf = scf::IfOp::create( + state.rewriter, loc, TypeRange {payload.getType()}, recvLateCond, /*withElseRegion=*/true); + rememberLateCommunicationOp(state, targetClass, lateReceiveIf.getOperation()); + Operation* lateThenYield = lateReceiveIf.thenBlock()->getTerminator(); + 
state.rewriter.setInsertionPoint(lateThenYield); + auto lateReceive = SpatChannelReceiveOp::create( + state.rewriter, loc, payload.getType(), recvChannelId, recvSourceCoreId, recvTargetCoreId); + markScalarCommunication( + lateReceive.getOperation(), getMinimumChannelId(messages.channelIds), "emitOrderedBatchToBatchCommunication.lateReceive"); + Value lateReceived = lateReceive.getOutput(); + state.rewriter.modifyOpInPlace(lateThenYield, [&] { lateThenYield->setOperands(ValueRange {lateReceived}); }); + Operation* lateElseYield = lateReceiveIf.elseBlock()->getTerminator(); + state.rewriter.modifyOpInPlace( + lateElseYield, [&] { lateElseYield->setOperands(ValueRange {earlyReceiveIf.getResult(0)}); }); + + return lateReceiveIf.getResult(0); +} + +LogicalResult emitClassToClassCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Value payload, + Location loc) { + if (sourceClass.id == targetClass.id) { + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, payload); + return success(); + } + + if (!sourceClass.isBatch) + return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); + + if (!targetClass.isBatch) { + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(sourceClass.cpus.size()); + + auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); + if (failed(targetCpu)) + return failure(); + for (CpuId sourceCpu : sourceClass.cpus) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); + if (failed(checkedSourceCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *targetCpu); + } + + if (failed(appendSend(state, sourceClass, payload, messages, loc))) + 
return failure(); + return registerLazyPackedScalarReceives(state, + sourceClass, + targetClass, + keys, + payload.getType(), + messages.channelIds, + messages.sourceCoreIds, + messages.targetCoreIds); + } + + if (sourceClass.cpus.size() != targetClass.cpus.size()) + return sourceClass.op->emitError( + "cannot materialize batch communication between equivalence classes of different sizes"); + + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(targetClass.cpus.size()); + + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id"); + if (failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + } + + FailureOr received = + emitOrderedBatchToBatchCommunication(state, sourceClass, targetClass, payload, messages, loc); + if (failed(received)) + return failure(); + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, *received); + + return success(); +} + +LogicalResult +setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { + auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); + if (resultIt == sourceClass.hostOutputToResultIndex.end()) + return sourceClass.op->emitError("missing host result slot for materialized output") + << " ownerKind=" << (sourceClass.isBatch ? "batch" : "scalar") + << " hostOutputs=" << sourceClass.hostOutputs.size() + << " originalDef=" << (originalOutput.getDefiningOp() ? 
originalOutput.getDefiningOp()->getName().getStringRef() + : StringRef("")); + + unsigned resultIndex = resultIt->second; + if (payload.getType() != originalOutput.getType()) + return sourceClass.op->emitError("cannot set host output from fragment payload without projection") + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); + + if (!sourceClass.isBatch) { + auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); + if (!yieldOp) + return sourceClass.op->emitError("expected spat.yield terminator in materialized compute"); + if (resultIndex >= yieldOp.getNumOperands()) + return sourceClass.op->emitError("host result index out of range for materialized compute"); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); + return success(); + } + + auto batch = cast(sourceClass.op); + auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); + if (!inParallelOp) + return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape()) + return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); + + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); + + auto outputArg = batch.getOutputArgument(resultIndex); + if (!outputArg) + return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); + + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + + createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); + return success(); +} + 
+FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex); + +LogicalResult emitProjectedBatchHostOutput(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value originalOutput, + Value payload, + Location loc) { + if (!sourceClass.isBatch) + return sourceClass.op->emitError("projected batch host publication expects a batch owner class"); + auto batch = cast(sourceClass.op); + + auto ownerIt = sourceClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == sourceClass.hostOutputToResultIndex.end()) + return sourceClass.op->emitError("missing host result slot for projected batch output"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + return sourceClass.op->emitError("projected batch host publication expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for publication"); + + auto sourceLaneArg = sourceBatch.getLaneArgument(); + if (!sourceLaneArg) + return sourceBatch.emitOpError("missing source compute_batch lane argument for host projection"); + + // The projection coordinates are part of the source batch publication. + // Build any affine/index helper ops in the source batch body, not at the + // caller's current insertion point. Otherwise a scalar host-owner body may + // accidentally capture the source scheduled_compute_batch lane argument. 
+ OpBuilder::InsertionGuard projectionGuard(state.rewriter); + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + FailureOr projectionLaneValue = createProjectionLaneValueForKeys(state, sourceClass, keys, loc); + if (failed(projectionLaneValue)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(projection->getMixedOffsets().size()); + sizes.reserve(projection->getMixedSizes().size()); + strides.reserve(projection->getMixedStrides().size()); + + for (OpFoldResult offset : projection->getMixedOffsets()) { + FailureOr remapped = + remapProjectionIndexLike(state, sourceClass.op, offset, *sourceLaneArg, *projectionLaneValue, loc); + if (failed(remapped)) + return sourceClass.op->emitError("failed to remap projected batch host offsets"); + offsets.push_back(*remapped); + } + for (OpFoldResult size : projection->getMixedSizes()) { + FailureOr remapped = + remapProjectionIndexLike(state, sourceClass.op, size, *sourceLaneArg, *projectionLaneValue, loc); + if (failed(remapped)) + return sourceClass.op->emitError("failed to remap projected batch host sizes"); + sizes.push_back(*remapped); + } + for (OpFoldResult stride : projection->getMixedStrides()) { + FailureOr remapped = + remapProjectionIndexLike(state, sourceClass.op, stride, *sourceLaneArg, *projectionLaneValue, loc); + if (failed(remapped)) + return sourceClass.op->emitError("failed to remap projected batch host strides"); + strides.push_back(*remapped); + } + + auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); + if (!inParallelOp) + return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); + + auto outputArg = batch.getOutputArgument(ownerIt->second); + if (!outputArg) + return batch.emitOpError("missing host output block argument for projected batch publication"); + + state.hostReplacements[originalOutput] = sourceClass.op->getResult(ownerIt->second); + 
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + tensor::ParallelInsertSliceOp::create(state.rewriter, loc, payload, *outputArg, offsets, sizes, strides); + return success(); +} + +FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane); + +FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) { + if (value == laneArg) + return static_cast(lane); + + if (std::optional constant = matchConstantIndexValue(value)) + return *constant; + + auto affineApply = value.getDefiningOp(); + if (!affineApply || affineApply.getAffineMap().getNumResults() != 1) + return failure(); + + SmallVector operands; + operands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane); + if (failed(evaluated)) + return failure(); + operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); + } + + SmallVector results; + if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1) + return failure(); + + auto intAttr = dyn_cast(results.front()); + if (!intAttr) + return failure(); + return intAttr.getInt(); +} + +FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) { + if (auto attr = llvm::dyn_cast(value)) { + auto intAttr = dyn_cast(attr); + if (!intAttr) + return failure(); + return intAttr.getInt(); + } + return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane); +} + +FailureOr +getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) { + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return failure(); + + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return failure(); + + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; 
+ + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (candidateIndex == resultIndex) + return insert; + } + + return failure(); +} + +FailureOr> +evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) { + SmallVector evaluated; + evaluated.reserve(values.size()); + for (OpFoldResult value : values) { + FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane); + if (failed(index)) + return failure(); + evaluated.push_back(*index); + } + return evaluated; +} + + +bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane) { + auto producerBatch = dyn_cast_or_null(producer.instance.op); + if (!producerBatch) + return true; + + FailureOr producerProjection = + getBatchResultProjectionInsert(producerBatch, producer.resultIndex); + if (failed(producerProjection)) + return true; + + std::optional producerLaneArg = producerBatch.getLaneArgument(); + std::optional consumerLaneArg = consumerBatch.getLaneArgument(); + if (!producerLaneArg || !consumerLaneArg) + return false; + + SmallVector consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end()); + SmallVector loopIterationIndices(match.loops.size(), 0); + + const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { + SmallVector consumerOffsets; + consumerOffsets.reserve(match.offsets.size()); + for (OpFoldResult offset : match.offsets) { + FailureOr evaluated = + evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices); + if (failed(evaluated)) + return false; + consumerOffsets.push_back(*evaluated); + } + + uint32_t producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount; + for (uint32_t producerLane = 
producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) { + FailureOr> producerOffsets = + evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane); + FailureOr> producerSizes = + evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane); + FailureOr> producerStrides = + evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane); + if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides)) + return false; + if (!areAllUnitStrides(*producerStrides)) + return false; + if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes)) + return true; + } + + return false; + }; + + if (match.loops.empty()) + return consumerSliceFitsOneProducerFragment(); + + const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { + if (loopIndex == match.loops.size()) + return consumerSliceFitsOneProducerFragment(); + + for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { + loopIterationIndices[loopIndex] = iteration; + if (!self(self, loopIndex + 1)) + return false; + } + return true; + }; + + return recurse(recurse, 0); +} + +LogicalResult insertProjectedBatchHostFragment(MaterializerState& state, + MaterializedClass& ownerClass, + Value originalOutput, + uint32_t lane, + Value payload) { + if (ownerClass.isBatch) + return ownerClass.op->emitError("projected batch host fallback expects a scalar owner class"); + + auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == ownerClass.hostOutputToResultIndex.end()) + return ownerClass.op->emitError("missing host result slot for projected batch host fragment"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + 
return ownerClass.op->emitError("projected batch host fallback expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for materialization"); + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument for host projection"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); + if (failed(offsets) || failed(sizes) || failed(strides)) + return ownerClass.op->emitError("failed to evaluate batch host projection coordinates"); + + auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); + if (!yieldOp) + return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); + + unsigned hostResultIndex = ownerIt->second; + if (hostResultIndex >= yieldOp.getNumOperands()) + return ownerClass.op->emitError("host result index out of range for projected batch host fragment"); + if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) + return ownerClass.op->emitError("projected batch host fragment expected a full host accumulator tensor") + << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() + << " outputType=" << originalOutput.getType(); + + state.rewriter.setInsertionPoint(yieldOp); + Value updated = tensor::InsertSliceOp::create(state.rewriter, + payload.getLoc(), + payload, + yieldOp.getOperand(hostResultIndex), + ValueRange {}, + ValueRange {}, + ValueRange {}, + *offsets, + *sizes, + *strides) + .getResult(); + state.rewriter.modifyOpInPlace(yieldOp, [&] { 
yieldOp->setOperand(hostResultIndex, updated); }); + state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); + return success(); +} + + +LogicalResult emitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, + MaterializedClass& ownerClass, + Value originalOutput, + ArrayRef keys, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + if (ownerClass.isBatch) + return ownerClass.op->emitError("projected batch host receive loop expects a scalar owner class"); + if (keys.empty()) + return success(); + if (keys.size() != messages.size()) + return ownerClass.op->emitError("projected batch host receive loop message metadata is inconsistent"); + + auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == ownerClass.hostOutputToResultIndex.end()) + return ownerClass.op->emitError("missing host result slot for projected batch host receive loop"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + return ownerClass.op->emitError("projected batch host receive loop expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for receive loop"); + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument for projected host receive loop"); + + auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); + if (!yieldOp) + return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); + + unsigned hostResultIndex = ownerIt->second; + if (hostResultIndex >= yieldOp.getNumOperands()) + return ownerClass.op->emitError("host result index out of range for 
projected batch host receive loop"); + if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) + return ownerClass.op->emitError("projected batch host receive loop expected a full host accumulator tensor") + << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() + << " outputType=" << originalOutput.getType(); + + unsigned rank = projection->getMixedOffsets().size(); + SmallVector, 4> offsetsByDim(rank); + SmallVector, 4> sizesByDim(rank); + SmallVector, 4> stridesByDim(rank); + for (ProducerKey key : keys) { + if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() + || key.instance.laneCount != 1) + return ownerClass.op->emitError("projected batch host receive loop expects one-lane fragments from one output"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) + return ownerClass.op->emitError("failed to evaluate projected batch host receive loop coordinates"); + if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) + return ownerClass.op->emitError("projected batch host receive loop coordinate rank mismatch"); + + for (unsigned dim = 0; dim < rank; ++dim) { + offsetsByDim[dim].push_back((*offsets)[dim]); + sizesByDim[dim].push_back((*sizes)[dim]); + stridesByDim[dim].push_back((*strides)[dim]); + } + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 1); 
+ + state.rewriter.setInsertionPoint(yieldOp); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {yieldOp.getOperand(hostResultIndex)}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value channelId = createIndexedChannelId(state, ownerClass.op, messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, ownerClass.op, messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, ownerClass.op, messages, flatIndex, loc); + Value fragment = SpatChannelReceiveOp::create( + state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); + } + + Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + markScalarCommunication( + loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitProjectedBatchHostReceiveInsertLoop"); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); + state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); + return success(); +} + +std::optional tryEmitProjectedBatchHostReceiveInsertLoop(MaterializerState& state, + MaterializedClass& ownerClass, + Value originalOutput, + ArrayRef keys, + Location 
loc) { + if (keys.empty()) + return success(); + + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(keys.front(), ownerClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); + for (size_t runIndex : runIndices) { + PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); + if (run.kind != PackedScalarRunKind::DeferredReceive) + continue; + SmallVector runKeys = flattenPackedScalarRunKeys(run); + if (!llvm::equal(runKeys, keys)) + continue; + return emitProjectedBatchHostReceiveInsertLoop( + state, ownerClass, originalOutput, runKeys, run.fragmentType, run.messages, loc); + } + + return std::nullopt; +} + +FailureOr getLeadingPackedFragmentType(Operation* anchor, Value payload, size_t fragmentCount) { + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return failure(); + if (payloadType.getDimSize(0) != static_cast(fragmentCount)) + return failure(); + + SmallVector fragmentShape(payloadType.getShape().begin(), payloadType.getShape().end()); + fragmentShape[0] = 1; + return RankedTensorType::get(fragmentShape, payloadType.getElementType(), payloadType.getEncoding()); +} + +LogicalResult emitScalarPackedProjectedHostSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "packed projected host send loop expects a scalar source"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return sourceClass.op->emitError("packed projected host send loop expects a static ranked payload"); + + setInsertionPointForScalarCommunication(state, sourceClass, 
getMinimumChannelId(messages.channelIds)); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(payloadType.getRank()); + sizes.reserve(payloadType.getRank()); + strides.reserve(payloadType.getRank()); + offsets.push_back(index); + sizes.push_back(state.rewriter.getIndexAttr(1)); + strides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { + offsets.push_back(state.rewriter.getIndexAttr(0)); + sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + Value fragment = tensor::ExtractSliceOp::create( + state.rewriter, loc, fragmentType, payload, offsets, sizes, strides) + .getResult(); + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, fragment); + return success(); + }); + if (failed(loop)) + return failure(); + markScalarCommunication( + loop->loop.getOperation(), getMinimumChannelId(messages.channelIds), "emitScalarPackedProjectedHostSendLoop"); + return success(); +} + +LogicalResult emitScalarPackedProjectedHostLocalInsertLoop(MaterializerState& state, + MaterializedClass& ownerClass, + ArrayRef keys, + Value payload, + 
Value originalOutput, + RankedTensorType fragmentType, + Location loc) { + if (ownerClass.isBatch) + return ownerClass.op->emitError("packed projected host local insert loop expects a scalar owner class"); + if (keys.empty()) + return success(); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape() || payloadType.getRank() == 0) + return ownerClass.op->emitError("packed projected host local insert loop expects a static ranked payload"); + if (payloadType.getDimSize(0) != static_cast(keys.size())) + return ownerClass.op->emitError("packed projected host local insert loop payload/key count mismatch"); + + auto ownerIt = ownerClass.hostOutputToResultIndex.find(originalOutput); + if (ownerIt == ownerClass.hostOutputToResultIndex.end()) + return ownerClass.op->emitError("missing host result slot for packed projected host local insert loop"); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + auto originalResult = dyn_cast(originalOutput); + if (!sourceBatch || sourceBatch.getNumResults() == 0 || !originalResult) + return ownerClass.op->emitError("packed projected host local insert loop expects a resultful compute_batch output"); + + FailureOr projection = + getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); + if (failed(projection)) + return sourceBatch.emitOpError("failed to recover batch host projection for local insert loop"); + + auto laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) + return sourceBatch.emitOpError("missing compute_batch lane argument for packed projected host local insert loop"); + + auto yieldOp = dyn_cast(ownerClass.body->getTerminator()); + if (!yieldOp) + return ownerClass.op->emitError("expected spat.yield terminator in scalar host owner"); + + unsigned hostResultIndex = ownerIt->second; + if (hostResultIndex >= yieldOp.getNumOperands()) + return ownerClass.op->emitError("host result index out of range for packed projected host local 
insert loop"); + if (yieldOp.getOperand(hostResultIndex).getType() != originalOutput.getType()) + return ownerClass.op->emitError("packed projected host local insert loop expected a full host accumulator tensor") + << " accumulatorType=" << yieldOp.getOperand(hostResultIndex).getType() + << " outputType=" << originalOutput.getType(); + + unsigned rank = projection->getMixedOffsets().size(); + SmallVector, 4> offsetsByDim(rank); + SmallVector, 4> sizesByDim(rank); + SmallVector, 4> stridesByDim(rank); + for (ProducerKey key : keys) { + if (key.instance.op != originalOutput.getDefiningOp() || key.resultIndex != originalResult.getResultNumber() + || key.instance.laneCount != 1) + return ownerClass.op->emitError("packed projected host local insert loop expects one-lane fragments from one output"); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) + return ownerClass.op->emitError("failed to evaluate packed projected host local insert loop coordinates"); + if (offsets->size() != rank || sizes->size() != rank || strides->size() != rank) + return ownerClass.op->emitError("packed projected host local insert loop coordinate rank mismatch"); + + for (unsigned dim = 0; dim < rank; ++dim) { + offsetsByDim[dim].push_back((*offsets)[dim]); + sizesByDim[dim].push_back((*sizes)[dim]); + stridesByDim[dim].push_back((*strides)[dim]); + } + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, ownerClass.op, 
1); + + state.rewriter.setInsertionPoint(yieldOp); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {yieldOp.getOperand(hostResultIndex)}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.reserve(payloadType.getRank()); + extractSizes.reserve(payloadType.getRank()); + extractStrides.reserve(payloadType.getRank()); + extractOffsets.push_back(flatIndex); + extractSizes.push_back(state.rewriter.getIndexAttr(1)); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { + extractOffsets.push_back(state.rewriter.getIndexAttr(0)); + extractSizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + } + + Value fragment = tensor::ExtractSliceOp::create( + state.rewriter, loc, fragmentType, payload, extractOffsets, extractSizes, extractStrides) + .getResult(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, ownerClass.op, offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, ownerClass.op, sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, ownerClass.op, stridesByDim[dim], flatIndex, loc)); + } + + Value updated = tensor::InsertSliceOp::create(state.rewriter, loc, fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(hostResultIndex, loop->results.front()); }); + 
state.hostReplacements[originalOutput] = ownerClass.op->getResult(hostResultIndex); + return success(); +} + +std::optional tryEmitScalarPackedProjectedHostPublication(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (sourceClass.isBatch || keys.size() <= 1) + return std::nullopt; + + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for projected batch output"); + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (ownerClass.isBatch) + return ownerClass.op->emitError( + "projected batch host output reached a batch owner without an explicit batch publication path"); + FailureOr fragmentType = getLeadingPackedFragmentType(sourceClass.op, payload, keys.size()); + if (failed(fragmentType)) + return std::nullopt; + + if (ownerClass.id == sourceClass.id) + return emitScalarPackedProjectedHostLocalInsertLoop( + state, ownerClass, keys, payload, originalOutput, *fragmentType, loc); + + auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); + auto targetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); + if (failed(sourceCpu) || failed(targetCpu)) + return failure(); + + MessageVector messages; + for ([[maybe_unused]] ProducerKey key : keys) + messages.append(state.nextChannelId++, *sourceCpu, *targetCpu); + + if (failed(messages.verify(sourceClass.op))) + return failure(); + + if (failed(emitScalarPackedProjectedHostSendLoop(state, sourceClass, payload, *fragmentType, messages, loc))) + return failure(); + + return emitProjectedBatchHostReceiveInsertLoop( + state, ownerClass, originalOutput, keys, *fragmentType, messages, loc); +} + +void appendPendingProjectedHostReceive(MaterializerState& state, + MaterializedClass& ownerClass, + Value 
originalOutput, + ProducerKey key, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + assert(messages.size() == 1 && "pending projected host receive records one message at a time"); + for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { + if (group.originalOutput != originalOutput || group.ownerClassId != ownerClass.id || group.fragmentType != fragmentType) + continue; + group.keys.push_back(key); + group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); + return; + } + + PendingProjectedHostReceiveGroup group { + originalOutput, + ownerClass.id, + fragmentType, + SmallVector{key}, + MessageVector{}, + loc + }; + group.messages.append(messages.channelIds, messages.sourceCoreIds, messages.targetCoreIds); + state.pendingProjectedHostReceives.push_back(std::move(group)); +} + +LogicalResult flushPendingProjectedHostReceives(MaterializerState& state) { + for (PendingProjectedHostReceiveGroup& group : state.pendingProjectedHostReceives) { + if (group.ownerClassId >= state.classes.size()) + return state.func.emitError("pending projected host receive has invalid owner class"); + MaterializedClass& ownerClass = state.classes[group.ownerClassId]; + if (failed(group.messages.verify(ownerClass.op))) + return failure(); + if (group.keys.empty()) + continue; + if (failed(emitProjectedBatchHostReceiveInsertLoop( + state, ownerClass, group.originalOutput, group.keys, group.fragmentType, group.messages, group.loc))) + return failure(); + } + state.pendingProjectedHostReceives.clear(); + return success(); +} + +LogicalResult emitProjectedBatchHostFragment(MaterializerState& state, + MaterializedClass& sourceClass, + ProducerKey key, + Value payload, + Value originalOutput, + Location loc) { + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for projected batch output"); 
+ + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + Value ownerPayload = payload; + if (sourceClass.id != ownerClass.id) { + if (ownerClass.isBatch) { + return ownerClass.op->emitError( + "projected batch host fragment reached a batch owner without an explicit batch publication path"); + } + + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "projected host source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected host target core id"); + if (failed(checkedTargetCpu)) + return failure(); + if (!sourceClass.isBatch) { + if (failed(checkedSourceCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + + if (failed(appendSend(state, sourceClass, payload, messages, loc))) + return failure(); + + auto fragmentType = dyn_cast(payload.getType()); + if (!fragmentType) + return sourceClass.op->emitError("projected terminal batch host fragment expects ranked tensor payload"); + appendPendingProjectedHostReceive(state, ownerClass, originalOutput, key, fragmentType, messages, loc); + return success(); + } + else { + ComputeInstance scheduledInstance = getScheduledChunkForLogicalInstance(state, key.instance); + auto sourceCpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); + if (sourceCpuIt == state.schedule.computeToCpuMap.end()) + return sourceClass.op->emitError("missing CPU assignment for projected batch host source"); + + auto localLaneIt = sourceClass.cpuToLane.find(sourceCpuIt->second); + if (localLaneIt == sourceClass.cpuToLane.end()) + return sourceClass.op->emitError("missing local batch lane for projected batch host source"); + + if (failed(checkedSourceCpu = getCheckedCoreId(sourceClass.op, + sourceCpuIt->second, + "projected host source core id"))) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + + auto batch = 
cast(sourceClass.op); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing lane argument for projected batch host source"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value localLane = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, localLaneIt->second); + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); + Value isSourceLane = arith::CmpIOp::create(state.rewriter, loc, arith::CmpIPredicate::eq, *laneArg, localLane); + auto ifOp = scf::IfOp::create(state.rewriter, loc, TypeRange {}, isSourceLane, /*withElseRegion=*/false); + state.rewriter.setInsertionPoint(ifOp.thenBlock()->getTerminator()); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, loc); + } + } + + return insertProjectedBatchHostFragment(state, ownerClass, originalOutput, key.instance.laneStart, ownerPayload); +} + +LogicalResult +emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, Value payload, Value originalOutput) { + if (!hasLiveExternalUseCached(state, originalOutput)) + return success(); + + if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) + return sourceClass.op->emitError("cannot set projected terminal batch host output through the generic host path"); + + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for live external output"); + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (sourceClass.id == 
ownerClass.id) + return setHostOutputValue(state, ownerClass, originalOutput, payload); + + if (sourceClass.isBatch) + return sourceClass.op->emitError("batch host publication must be routed through a projection-aware or owning path"); + if (ownerClass.isBatch) + return ownerClass.op->emitError("generic host publication does not support batch host owners"); + + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "host target core id"); + if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + + if (failed(appendSend(state, sourceClass, payload, messages, payload.getLoc()))) + return failure(); + Value ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, payload.getLoc()); + return setHostOutputValue(state, ownerClass, originalOutput, ownerPayload); +} + +LogicalResult emitOutputFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (keys.empty()) + return success(); + + if (!sourceClass.isBatch) { + if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) + return failure(); + + if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { + std::optional loopedHostPublication = + tryEmitScalarPackedProjectedHostPublication(state, sourceClass, keys, payload, originalOutput, loc); + if (loopedHostPublication) + return *loopedHostPublication; + + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("projected terminal batch host output expects one logical lane per fragment"); + if (failed(emitProjectedBatchHostFragment(state, sourceClass, key, payload, originalOutput, loc))) + return failure(); + } + 
return success(); + } + + return emitHostCommunication(state, sourceClass, payload, originalOutput); + } + + if (!haveSameDestinationClasses(state, keys)) + return sourceClass.op->emitError( + "cannot materialize batched output whose lanes have different destination equivalence classes"); + + if (auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp())) { + if (sourceBatch.getNumResults() != 0 && isTerminalHostBatchOutput(originalOutput, state.oldComputeOps)) { + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) + if (!state.classes[destinationClass].isBatch) + return emitBatchToScalarDestinationDiagnostic(state, sourceClass, keys, originalOutput); + } + } + + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) + if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) + return failure(); + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + if (sourceBatch && sourceBatch.getNumResults() != 0 && hasLiveExternalUseCached(state, originalOutput)) { + if (sourceClass.hostOutputToResultIndex.contains(originalOutput)) { + if (failed(emitProjectedBatchHostOutput(state, sourceClass, keys, originalOutput, payload, loc))) + return failure(); + } + else { + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for projected batch output"); + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (ownerClass.isBatch) + return ownerClass.op->emitError( + "projected batch host output reached a batch owner without an explicit batch publication path"); + + if (sourceClass.id != ownerClass.id + && failed(emitClassToClassCommunication(state, sourceClass, ownerClass, keys, payload, loc))) + return failure(); + + std::optional loopedHostPublication = + tryEmitProjectedBatchHostReceiveInsertLoop(state, ownerClass, 
originalOutput, keys, loc); + if (loopedHostPublication) { + if (failed(*loopedHostPublication)) + return failure(); + } + else { + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("projected batch host output expects one logical lane per fragment"); + + std::optional ownerPayload = state.availableValues.lookup(state, key, ownerClass.id); + if (!ownerPayload) + return ownerClass.op->emitError("failed to recover projected batch host fragment after communication"); + + if (failed(insertProjectedBatchHostFragment( + state, ownerClass, originalOutput, key.instance.laneStart, *ownerPayload))) + return failure(); + } + } + } + } else if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput))) { + return failure(); + } + + for (ProducerKey key : keys) + state.availableValues.record(key, sourceClass.id, payload); + + return success(); +} + +struct DirectWholeBatchFragment { + ProducerKey key; + Value fragment; +}; + +enum class WholeBatchFragmentSourceKind { + DeferredReceive, + DeferredLocalCompute, + PackedValue, + DirectValue +}; + +struct WholeBatchFragmentGroup { + WholeBatchFragmentSourceKind kind = WholeBatchFragmentSourceKind::DirectValue; + RankedTensorType fragmentType; + SmallVector outputOffsets; + MessageVector messages; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + SmallVector sourceLanes; + Value packed; + RankedTensorType slotPackedType; + SmallVector slotIndices; + SmallVector, 16> directFragments; + SmallVector redundantReceives; +}; + +enum class ProjectedWholeBatchFragmentSourceKind { + DeferredReceive, + PackedValue, + DirectValue +}; + +struct ProjectedWholeBatchDirectFragment { + Value fragment; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +struct ProjectedWholeBatchFragmentGroup { + ProjectedWholeBatchFragmentSourceKind kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; + RankedTensorType fragmentType; + SmallVector, 4> 
offsetsByDim; + SmallVector, 4> sizesByDim; + SmallVector, 4> stridesByDim; + MessageVector messages; + SmallVector redundantOps; + Value packed; + RankedTensorType packedSourceType; + SmallVector packedIndices; + SmallVector directFragments; +}; + +struct WholeBatchAssemblyPlan { + RankedTensorType resultType; + int64_t rowsPerLane = 0; + uint32_t batchLaneCount = 0; + uint32_t coveredLaneCount = 0; + + SmallVector coveredLanes; + SmallVector packedRuns; + SmallVector directFragments; +}; + +bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) { + return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0; +} + +bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { + if (laneCount == 0) + return false; + if (laneStart >= plan.coveredLanes.size()) + return false; + + uint32_t laneEnd = std::min(laneStart + laneCount, plan.coveredLanes.size()); + for (uint32_t lane = laneStart; lane < laneEnd; ++lane) + if (plan.coveredLanes[lane] != 0) + return true; + return false; +} + +void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { + assert(laneCount != 0 && "cannot cover an empty whole-batch range"); + assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds"); + + for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) { + if (plan.coveredLanes[lane] != 0) + continue; + plan.coveredLanes[lane] = 1; + ++plan.coveredLaneCount; + } +} + +bool localLaneRangeOverlaps(ArrayRef covered, uint32_t laneStart, uint32_t laneCount) { + if (laneCount == 0) + return false; + if (laneStart >= covered.size()) + return false; + + uint32_t laneEnd = std::min(laneStart + laneCount, covered.size()); + for (uint32_t lane = laneStart; lane < laneEnd; ++lane) + if (covered[lane] != 0) + return true; + return false; +} + +void markLocalLaneRangeCovered(MutableArrayRef covered, uint32_t laneStart, uint32_t 
laneCount) { + assert(laneStart + laneCount <= covered.size() && "local coverage out of bounds"); + for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) + covered[lane] = 1; +} + +LogicalResult +validateWholeBatchFragmentType(RankedTensorType resultType, RankedTensorType fragmentType, int64_t expectedRows) { + if (!fragmentType.hasStaticShape()) + return failure(); + if (fragmentType.getRank() != resultType.getRank()) + return failure(); + if (fragmentType.getDimSize(0) != expectedRows) + return failure(); + + for (int64_t dim = 1; dim < resultType.getRank(); ++dim) + if (fragmentType.getDimSize(dim) != resultType.getDimSize(dim)) + return failure(); + + return success(); +} + +// ----------------------------------------------------------------------------- +// Packed run tensor assembly helpers. +// ----------------------------------------------------------------------------- + +FailureOr insertFragmentIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value fragment, + Value destination, + OpFoldResult firstOffset, + Location loc) { + return createDim0InsertSliceInClass(state, targetClass, loc, fragment, destination, firstOffset); +} + +FailureOr extractPackedSlotForIndex(MaterializerState& state, + MaterializedClass& targetClass, + Value packed, + RankedTensorType slotPackedType, + Value slotIndex, + Location loc) { + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0)); +} + +SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { + SmallVector keys; + for (const PackedScalarRunSlot& slot : run.slots) + llvm::append_range(keys, slot.keys); + return keys; +} + +bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScalarRunValue& rhs) { + if 
(lhs.slots.size() != rhs.slots.size()) + return false; + + for (auto [lhsSlot, rhsSlot] : llvm::zip(lhs.slots, rhs.slots)) { + if (lhsSlot.keys.size() != rhsSlot.keys.size()) + return false; + if (!llvm::equal(lhsSlot.keys, rhsSlot.keys)) + return false; + } + + return true; +} + + +bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelReceiveOp receive) { + std::optional channelId = getConstantIndexValue(receive.getChannelId()); + std::optional sourceCoreId = getConstantIndexValue(receive.getSourceCoreId()); + std::optional targetCoreId = getConstantIndexValue(receive.getTargetCoreId()); + if (!channelId || !sourceCoreId || !targetCoreId) + return false; + messages.append(*channelId, static_cast(*sourceCoreId), static_cast(*targetCoreId)); + return true; +} + +PackedScalarRunValue* findDeferredReceiveAlternativeForPackedRun(MaterializerState& state, + const MaterializedClass& targetClass, + const PackedScalarRunValue& run) { + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); + + for (size_t runIndex : runIndices) { + PackedScalarRunValue& candidate = state.availableValues.getPackedRun(runIndex); + if (&candidate == &run || candidate.kind != PackedScalarRunKind::DeferredReceive) + continue; + if (candidate.fragmentType != run.fragmentType) + continue; + if (!packedScalarRunSlotsMatch(candidate, run)) + continue; + return &candidate; + } + + return nullptr; +} + +FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + int64_t itemCount, + IndexedFragmentBuilder buildFragment, + IndexedInsertOffsetBuilder buildOffset, + Location loc) { + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, itemCount); + 
Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + Operation* insertionPoint = targetClass.body->getTerminator(); + + state.rewriter.setInsertionPoint(insertionPoint); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + FailureOr fragment = buildFragment(flatIndex); + if (failed(fragment)) + return failure(); + FailureOr offset = buildOffset(flatIndex); + if (failed(offset)) + return failure(); + FailureOr next = + insertFragmentIntoWholeBatch(state, targetClass, *fragment, iterArgs.front(), *offset, loc); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value laneValue, + ArrayRef resultIndices, + CloneIndexingContext indexing = {}); + +Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc); +FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, + MaterializedClass& targetClass, + IndexedBatchRunValue& run, + Value runSlotIndex, + Location loc); + +FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc) { + assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); + + SmallVector keys = flattenPackedScalarRunKeys(run); + if (keys.empty()) + return failure(); + FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); + if (failed(packedType)) + return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); + + SmallVector sourceLanes; + 
sourceLanes.reserve(keys.size()); + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return failure(); + sourceLanes.push_back(key.instance.laneStart); + } + + SmallVector resultIndices {run.resultIndex}; + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + keys.front().instance, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); + if (failed(produced) || produced->size() != 1) + return failure(); + + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + run.packed = loop->results.front(); + return run.packed; +} + +LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + WholeBatchAssemblyPlan& plan) { + 
WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); + + for (size_t runIndex : runIndices) { + PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); + + SmallVector runKeys; + SmallVector runCoveredLanes(plan.batchLaneCount, 0); + + for (const PackedScalarRunSlot& slot : run.slots) { + for (ProducerKey fragmentKey : slot.keys) { + if (fragmentKey.instance.op != key.instance.op || fragmentKey.resultIndex != key.resultIndex) + return failure(); + + if (fragmentKey.instance.laneCount == 0) + return failure(); + + if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + return failure(); + + if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + return failure(); + + markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount); + runKeys.push_back(fragmentKey); + } + } + + if (runKeys.empty()) + continue; + + plan.packedRuns.push_back(&run); + + for (ProducerKey runKey : runKeys) + recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount); + } + + return success(); +} + +LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + SpatComputeBatch batch, + ProducerKey key, + WholeBatchAssemblyPlan& plan) { + struct CandidateFragment { + ProducerKey key; + Value value; + }; + + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + if (plan.coveredLaneCount == plan.batchLaneCount) { + return success(); + } + + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); + ArrayRef indexedFragments = + state.availableValues.getExactFragmentsForWholeBatch(lookupKey); + + SmallVector candidates; + candidates.reserve(indexedFragments.size()); + for (const 
AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) { + ProducerKey candidateKey = record.key; + if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex + || candidateKey.instance.laneCount == 0) + continue; + if (!isTensorValueLocalToMaterializedClass(record.value, targetClass)) + continue; + if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount)) + continue; + + auto fragmentType = dyn_cast(record.value.getType()); + if (!fragmentType) + continue; + + int64_t expectedRows = plan.rowsPerLane * static_cast(candidateKey.instance.laneCount); + if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) + continue; + + candidates.push_back({candidateKey, record.value}); + } + + llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) { + if (lhs.key.instance.laneStart != rhs.key.instance.laneStart) + return lhs.key.instance.laneStart < rhs.key.instance.laneStart; + return lhs.key.instance.laneCount > rhs.key.instance.laneCount; + }); + + size_t candidateCursor = 0; + uint32_t lane = 0; + while (lane < batchLaneCount) { + while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) { + ++lane; + } + + if (lane >= batchLaneCount) + break; + + while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane) + ++candidateCursor; + + size_t candidateIndex = candidateCursor; + const CandidateFragment* best = nullptr; + while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) { + const CandidateFragment& candidate = candidates[candidateIndex]; + if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) { + best = &candidate; + break; + } + ++candidateIndex; + } + + if (!best) + return failure(); + + plan.directFragments.push_back({best->key, best->value}); + recordWholeBatchCoverage(plan, lane, 
best->key.instance.laneCount); + lane += best->key.instance.laneCount; + } + + return success(); +} + +LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, + MaterializedClass& targetClass, + const WholeBatchAssemblyPlan& plan, + SmallVectorImpl& groups) { + for (PackedScalarRunValue* run : plan.packedRuns) { + if (!run || run->slots.empty()) + continue; + if (run->fragmentType.getDimSize(0) != plan.rowsPerLane) + return failure(); + + if (run->kind == PackedScalarRunKind::Materialized && run->packed + && !isTensorValueLocalToMaterializedClass(run->packed, targetClass)) { + if (PackedScalarRunValue* deferredRun = findDeferredReceiveAlternativeForPackedRun(state, targetClass, *run)) + run = deferredRun; + else { + SmallVector keys = flattenPackedScalarRunKeys(*run); + std::optional packedKey = getContiguousProducerRangeForKeys(keys); + emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, + targetClass, + "whole-batch assembly tried to reuse non-local PackedValue", + run->packed, + packedKey); + return failure(); + } + } + + if (run->kind == PackedScalarRunKind::DeferredReceive) { + if (failed(validatePackedScalarRunMetadata(targetClass.op, *run))) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == run->fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = run->fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + + groupIt->messages.append(run->messages.channelIds, run->messages.sourceCoreIds, run->messages.targetCoreIds); + for (const PackedScalarRunSlot& slot : run->slots) + for (ProducerKey fragmentKey : slot.keys) + groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + continue; + } + + if 
(run->kind == PackedScalarRunKind::DeferredLocalCompute) { + SmallVector keys = flattenPackedScalarRunKeys(*run); + if (keys.empty()) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredLocalCompute + && group.fragmentType == run->fragmentType && group.sourceOp == run->sourceOp + && group.resultIndex == run->resultIndex; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredLocalCompute; + group.fragmentType = run->fragmentType; + group.sourceOp = run->sourceOp; + group.resultIndex = run->resultIndex; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + + for (ProducerKey fragmentKey : keys) { + if (fragmentKey.instance.laneCount != 1) + return failure(); + groupIt->sourceLanes.push_back(fragmentKey.instance.laneStart); + groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + } + continue; + } + + auto sourceBatch = dyn_cast_or_null(run->sourceOp); + if (!sourceBatch || !run->packed) + return failure(); + + auto getOrCreatePackedValueGroup = [&](RankedTensorType slotPackedType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType + && group.packed == run->packed && group.slotPackedType == slotPackedType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::PackedValue; + group.fragmentType = run->fragmentType; + group.packed = run->packed; + group.slotPackedType = slotPackedType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + + size_t flattenedIndexBase = 0; + for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { + 
std::optional contiguousKey = getPhysicallyContiguousProducerRangeForKeys(slot.keys); + if (contiguousKey) { + FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return failure(); + WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType); + group.slotIndices.push_back(slotIndex); + group.outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); + flattenedIndexBase += slot.keys.size(); + continue; + } + + WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(run->fragmentType); + for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) { + group.slotIndices.push_back(flattenedIndexBase + keyIndex); + group.outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + } + flattenedIndexBase += slot.keys.size(); + } + } + + auto getOrCreateDeferredReceiveGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + + auto getOrCreateDirectValueGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DirectValue; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = 
std::prev(groups.end()); + } + return *groupIt; + }; + + for (const DirectWholeBatchFragment& fragment : plan.directFragments) { + if (!isTensorValueLocalToMaterializedClass(fragment.fragment, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, + targetClass, + "whole-batch assembly tried to reuse non-local DirectValue", + fragment.fragment, + fragment.key); + return failure(); + } + + auto fragmentType = dyn_cast(fragment.fragment.getType()); + if (!fragmentType) + return failure(); + + int64_t outputOffset = static_cast(fragment.key.instance.laneStart) * plan.rowsPerLane; + + if (auto receive = fragment.fragment.getDefiningOp()) { + if (fragment.fragment.use_empty()) { + WholeBatchFragmentGroup& group = getOrCreateDeferredReceiveGroup(fragmentType); + if (appendConstantChannelReceiveMessage(group.messages, receive)) { + group.outputOffsets.push_back(outputOffset); + group.redundantReceives.push_back(receive.getOperation()); + continue; + } + } + } + + WholeBatchFragmentGroup& group = getOrCreateDirectValueGroup(fragmentType); + group.directFragments.push_back({fragment.fragment, outputOffset}); + } + + return success(); +} + +FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchFragmentGroup& group, + Location loc) { + switch (group.kind) { + case WholeBatchFragmentSourceKind::DeferredReceive: { + FailureOr updated = emitIndexedFragmentInsertLoop( + state, + targetClass, + destination, + static_cast(group.outputOffsets.size()), + [&](Value flatIndex) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); + return SpatChannelReceiveOp::create( + state.rewriter, loc, group.fragmentType, 
channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + if (failed(updated)) + return failure(); + + for (Operation* receive : group.redundantReceives) + if (receive && receive->use_empty()) + receive->erase(); + + return *updated; + } + case WholeBatchFragmentSourceKind::DeferredLocalCompute: { + SmallVector resultIndices {group.resultIndex}; + return emitIndexedFragmentInsertLoop( + state, + targetClass, + destination, + static_cast(group.outputOffsets.size()), + [&](Value flatIndex) -> FailureOr { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, group.sourceLanes, flatIndex, loc); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + ComputeInstance {group.sourceOp, 0, 1}, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = flatIndex, .projectionSlotIndex = flatIndex}); + if (failed(produced) || produced->size() != 1) + return failure(); + return produced->front(); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + } + case WholeBatchFragmentSourceKind::PackedValue: + return emitIndexedFragmentInsertLoop( + state, + targetClass, + destination, + static_cast(group.slotIndices.size()), + [&](Value flatIndex) -> FailureOr { + Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); + FailureOr packed = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + group.packed, + targetClass.op, + "whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(packed)) + return failure(); + return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc); + }, + [&](Value flatIndex) -> FailureOr { + return 
createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + case WholeBatchFragmentSourceKind::DirectValue: + for (const auto& [fragment, offset] : group.directFragments) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment, + targetClass.op, + "whole-batch direct fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(localFragment)) + return failure(); + FailureOr updated = createDim0InsertSliceInClass(state, + targetClass, + loc, + *localFragment, + destination, + getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); + if (failed(updated)) + return failure(); + destination = *updated; + } + return destination; + } + + return failure(); +} + +FailureOr emitProjectedWholeBatchFragmentInsertLoop( + MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const ProjectedWholeBatchFragmentGroup& group, + llvm::function_ref(Value)> buildFragment, + Location loc) { + assert(group.fragmentType && "expected projected fragment type"); + assert(!group.offsetsByDim.empty() && "expected projected insert coordinates"); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + FailureOr fragment = buildFragment(flatIndex); + if (failed(fragment)) + return failure(); + + SmallVector offsets; + SmallVector sizes; 
+ SmallVector strides; + unsigned rank = group.offsetsByDim.size(); + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.stridesByDim[dim], flatIndex, loc)); + } + + Value updated = + tensor::InsertSliceOp::create(state.rewriter, loc, *fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +std::optional getStaticProjectedPackedFragmentIndex(tensor::ExtractSliceOp extract) { + auto sourceType = dyn_cast(extract.getSource().getType()); + auto resultType = dyn_cast(extract.getResult().getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() + || sourceType.getRank() == 0 || sourceType.getRank() != resultType.getRank()) + return std::nullopt; + + std::optional firstOffset = getConstantIndex(extract.getMixedOffsets().front()); + if (!firstOffset) + return std::nullopt; + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + std::optional offset = getConstantIndex(extract.getMixedOffsets()[dim]); + std::optional size = getConstantIndex(extract.getMixedSizes()[dim]); + std::optional stride = getConstantIndex(extract.getMixedStrides()[dim]); + if (!offset || !size || !stride || *stride != 1 || *size != resultType.getDimSize(dim)) + return std::nullopt; + if (dim != 0 && *offset != 0) + return std::nullopt; + } + + return *firstOffset; +} + +void appendProjectedInsertCoordinates(ProjectedWholeBatchFragmentGroup& group, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + if (group.offsetsByDim.empty()) 
{ + size_t rank = offsets.size(); + group.offsetsByDim.resize(rank); + group.sizesByDim.resize(rank); + group.stridesByDim.resize(rank); + } + + for (size_t dim = 0; dim < offsets.size(); ++dim) { + group.offsetsByDim[dim].push_back(offsets[dim]); + group.sizesByDim[dim].push_back(sizes[dim]); + group.stridesByDim[dim].push_back(strides[dim]); + } +} + +FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType) { + auto batch = dyn_cast_or_null(key.instance.op); + auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || !resultTensorType.hasStaticShape() || resultTensorType.getRank() == 0) + return failure(); + + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + if (batchLaneCount == 0 || resultTensorType.getDimSize(0) % static_cast(batchLaneCount) != 0) + return failure(); + + WholeBatchAssemblyPlan plan; + plan.resultType = resultTensorType; + plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast(batchLaneCount); + plan.batchLaneCount = batchLaneCount; + plan.coveredLanes.assign(batchLaneCount, 0); + + if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) + return failure(); + + if (plan.coveredLaneCount == plan.batchLaneCount) + return plan; + + if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) + return failure(); + + return plan; +} + +FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + WholeBatchAssemblyPlan& plan, + Location loc) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, plan.resultType.getShape(), plan.resultType.getElementType()) + .getResult(); + + SmallVector groups; + if (failed(collectWholeBatchFragmentGroups(state, targetClass, plan, groups))) + return failure(); + + for (const WholeBatchFragmentGroup& 
group : groups) {
+    FailureOr updated = emitWholeBatchFragmentGroup(state, targetClass, result, group, loc);
+    if (failed(updated))
+      return failure();
+    result = *updated; // each group is emitted on top of the value produced by the previous one
+  }
+
+  state.availableValues.record(key, targetClass.id, result); // publish the assembled value under (key, class id)
+  return result;
+}
+
+// ----------------------------------------------------------------------------- +// Run materialization helpers. +// ----------------------------------------------------------------------------- + /* Assembles the whole-batch value for `key` inside `targetClass` from per-lane fragments. For every lane in [laneStart, laneStart + laneCount) the static offsets/sizes/strides of the batch result projection insert are evaluated and the lane fragment is inserted into a fresh tensor.empty of `resultType`. Batch target classes insert each looked-up fragment directly; other classes first group fragments by source kind and emit one group at a time. Returns failure when the producer op or the static-shaped result tensor type cannot be obtained, the projection is missing, the lane range exceeds the batch lane count, or any fragment or projection index cannot be resolved. The result is recorded in state.availableValues. */+FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType, + Location loc) { + auto batch = dyn_cast_or_null(key.instance.op); + auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || !resultTensorType.hasStaticShape()) + return failure(); + + FailureOr projection = getBatchResultProjectionInsert(batch, key.resultIndex); + if (failed(projection)) + return failure(); + + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing compute_batch lane argument while materializing projected whole-batch input"); + + uint32_t laneEnd = key.instance.laneStart + key.instance.laneCount; + if (laneEnd > static_cast(batch.getLaneCount())) + return failure(); + + if (targetClass.isBatch) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) + .getResult(); + + for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { + ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); + std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); + if (!fragment) + return failure(); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(),
*laneArg, lane);
+      FailureOr> strides =
+          evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane);
+      if (failed(offsets) || failed(sizes) || failed(strides))
+        return failure();
+
+      SmallVector offsetAttrs;
+      SmallVector sizeAttrs;
+      SmallVector strideAttrs;
+      offsetAttrs.reserve(offsets->size());
+      sizeAttrs.reserve(sizes->size());
+      strideAttrs.reserve(strides->size());
+      for (auto [offset, size, stride] : llvm::zip(*offsets, *sizes, *strides)) {
+        offsetAttrs.push_back(state.rewriter.getIndexAttr(offset));
+        sizeAttrs.push_back(state.rewriter.getIndexAttr(size));
+        strideAttrs.push_back(state.rewriter.getIndexAttr(stride));
+      }
+
+      FailureOr localFragment = materializeTensorValueForMaterializedClassUse(
+          state,
+          targetClass,
+          *fragment,
+          targetClass.op,
+          "projected whole-batch assembly tried to reuse a tensor from another materialized class",
+          laneKey);
+      if (failed(localFragment))
+        return failure();
+
+      state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); // re-set before every insert, right before the body terminator
+      result = tensor::InsertSliceOp::create(
+                   state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs)
+                   .getResult();
+    }
+
+    state.availableValues.record(key, targetClass.id, result);
+    return result;
+  }
+  // Non-batch target class: bucket the lane fragments by source kind, fragment type and (for packed values) packed source before emitting anything.
+  SmallVector groups;
+  auto getOrCreateReceiveGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& {
+    auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) {
+      return group.kind == ProjectedWholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType;
+    });
+    if (groupIt == groups.end()) {
+      ProjectedWholeBatchFragmentGroup group;
+      group.kind = ProjectedWholeBatchFragmentSourceKind::DeferredReceive;
+      group.fragmentType = fragmentType;
+      groups.push_back(std::move(group));
+      groupIt = std::prev(groups.end());
+    }
+    return *groupIt; // reference into `groups`; callers below use it before the next group is appended
+  };
+  auto getOrCreatePackedGroup = [&](Value packed,
+                                    RankedTensorType packedSourceType,
+                                    RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& {
+    auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) {
+      return group.kind == ProjectedWholeBatchFragmentSourceKind::PackedValue && group.fragmentType == fragmentType
+          && group.packed == packed && group.packedSourceType == packedSourceType;
+    });
+    if (groupIt == groups.end()) {
+      ProjectedWholeBatchFragmentGroup group;
+      group.kind = ProjectedWholeBatchFragmentSourceKind::PackedValue;
+      group.fragmentType = fragmentType;
+      group.packed = packed;
+      group.packedSourceType = packedSourceType;
+      groups.push_back(std::move(group));
+      groupIt = std::prev(groups.end());
+    }
+    return *groupIt;
+  };
+  auto getOrCreateDirectGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& {
+    auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) {
+      return group.kind == ProjectedWholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType;
+    });
+    if (groupIt == groups.end()) {
+      ProjectedWholeBatchFragmentGroup group;
+      group.kind = ProjectedWholeBatchFragmentSourceKind::DirectValue;
+      group.fragmentType = fragmentType;
+      groups.push_back(std::move(group));
+      groupIt = std::prev(groups.end());
+    }
+    return *groupIt;
+  };
+  // Classification pass: nothing is emitted here, each lane is only assigned to a group.
+  for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) {
+    ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex);
+    FailureOr> offsets =
+        evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane);
+    FailureOr> sizes =
+        evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane);
+    FailureOr> strides =
+        evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane);
+    if (failed(offsets) || failed(sizes) || failed(strides))
+      return failure();
+    // First choice: fold an exactly matching, still unused receive into a deferred-receive group.
+    bool grouped = false;
+    if (std::optional exact = state.availableValues.lookupExact(laneKey, targetClass.id)) {
+      if (auto receive = exact->getDefiningOp()) {
+        auto fragmentType = dyn_cast(receive.getOutput().getType());
+        if (fragmentType && receive.getOutput().use_empty()) {
+          ProjectedWholeBatchFragmentGroup& group = getOrCreateReceiveGroup(fragmentType);
+          if (appendConstantChannelReceiveMessage(group.messages, receive)) {
+            appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides);
+            group.redundantOps.push_back(receive.getOperation()); // erased after emission if it is still unused
+            grouped = true;
+          }
+        }
+      }
+    }
+
+    if (grouped)
+      continue;
+
+    std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id);
+    if (!fragment)
+      return failure();
+
+    auto fragmentType = dyn_cast(fragment->getType());
+    if (!fragmentType)
+      return failure();
+    // Second choice: a fragment extracted at a static packed index is grouped by its packed source.
+    if (auto extract = fragment->getDefiningOp()) {
+      if (std::optional packedIndex = getStaticProjectedPackedFragmentIndex(extract)) {
+        auto packedSourceType = dyn_cast(extract.getSource().getType());
+        if (packedSourceType) {
+          ProjectedWholeBatchFragmentGroup& group =
+              getOrCreatePackedGroup(extract.getSource(), packedSourceType, fragmentType);
+          group.packedIndices.push_back(*packedIndex);
+          appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides);
+          group.redundantOps.push_back(extract.getOperation());
+          continue;
+        }
+      }
+    }
+    // Fallback: keep the fragment value itself together with its static insert coordinates.
+    ProjectedWholeBatchFragmentGroup& group = getOrCreateDirectGroup(fragmentType);
+    group.directFragments.push_back({*fragment, *offsets, *sizes, *strides});
+  }
+
+  state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
+  Value result =
+      tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType())
+          .getResult();
+  // Emission pass: every group updates `result` in the order the groups were created.
+  for (const ProjectedWholeBatchFragmentGroup& group : groups) {
+    FailureOr updated = failure(); // each case below assigns it
+    switch (group.kind) {
+    case ProjectedWholeBatchFragmentSourceKind::DeferredReceive:
+      updated = emitProjectedWholeBatchFragmentInsertLoop(
+          state,
+          targetClass,
+          result,
+          group,
+          [&](Value flatIndex) -> FailureOr {
+            Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc);
+            Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc);
+            Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc);
+            return SpatChannelReceiveOp::create(
+                       state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId)
+                       .getOutput();
+          },
+          loc);
+      break;
+    case ProjectedWholeBatchFragmentSourceKind::PackedValue:
+      updated = emitProjectedWholeBatchFragmentInsertLoop(
+          state,
+          targetClass,
+          result,
+          group,
+          [&](Value flatIndex) -> FailureOr {
+            SmallVector extractOffsets;
+            SmallVector extractSizes;
+            SmallVector extractStrides;
+            extractOffsets.reserve(group.packedSourceType.getRank());
+            extractSizes.reserve(group.packedSourceType.getRank());
+            extractStrides.reserve(group.packedSourceType.getRank());
+            extractOffsets.push_back(createIndexedOrStaticIndex(
+                state, targetClass.op, group.packedIndices, flatIndex, loc)); // dimension 0 selects the packed fragment
+            extractSizes.push_back(state.rewriter.getIndexAttr(1));
+            extractStrides.push_back(state.rewriter.getIndexAttr(1));
+            for (int64_t dim = 1; dim < group.packedSourceType.getRank(); ++dim) { // remaining dimensions are taken whole
+              extractOffsets.push_back(state.rewriter.getIndexAttr(0));
+              extractSizes.push_back(state.rewriter.getIndexAttr(group.packedSourceType.getDimSize(dim)));
+              extractStrides.push_back(state.rewriter.getIndexAttr(1));
+            }
+
+            FailureOr packed = materializeTensorValueForMaterializedClassUse(
+                state,
+                targetClass,
+                group.packed,
+                targetClass.op,
+                "projected whole-batch packed fragment assembly tried to reuse a tensor from another materialized class");
+            if (failed(packed))
+              return failure();
+
+            return tensor::ExtractSliceOp::create(
+                       state.rewriter,
+                       loc,
+                       group.fragmentType,
+                       *packed,
+                       extractOffsets,
+                       extractSizes,
+                       extractStrides)
+                       .getResult();
+          },
+          loc);
+      break;
+    case ProjectedWholeBatchFragmentSourceKind::DirectValue: {
+      updated = result;
+      for (const
ProjectedWholeBatchDirectFragment& fragment : group.directFragments) {
+        FailureOr localFragment = materializeTensorValueForMaterializedClassUse(
+            state,
+            targetClass,
+            fragment.fragment,
+            targetClass.op,
+            "projected whole-batch assembly tried to reuse a tensor from another materialized class");
+        if (failed(localFragment))
+          return failure();
+
+        SmallVector offsetAttrs;
+        SmallVector sizeAttrs;
+        SmallVector strideAttrs;
+        for (auto [offset, size, stride] : llvm::zip(fragment.offsets, fragment.sizes, fragment.strides)) {
+          offsetAttrs.push_back(state.rewriter.getIndexAttr(offset));
+          sizeAttrs.push_back(state.rewriter.getIndexAttr(size));
+          strideAttrs.push_back(state.rewriter.getIndexAttr(stride));
+        }
+        updated = tensor::InsertSliceOp::create(
+                      state.rewriter, loc, *localFragment, *updated, offsetAttrs, sizeAttrs, strideAttrs)
+                      .getResult(); // direct fragments are chained: each insert consumes the previous insert's result
+      }
+      break;
+    }
+    }
+    if (failed(updated))
+      return failure();
+    result = *updated;
+  }
+  // Erase the ops recorded as redundant during grouping, but only those that ended up without uses.
+  // NOTE(review): the values of erased ops were obtained from state.availableValues; confirm no stale entry for their lane keys is looked up afterwards.
+  for (const ProjectedWholeBatchFragmentGroup& group : groups)
+    for (Operation* redundantOp : group.redundantOps)
+      if (redundantOp && redundantOp->use_empty())
+        redundantOp->erase();
+
+  state.availableValues.record(key, targetClass.id, result);
+  return result;
+}
+/*
+ * Materializes the whole-batch value for `key` inside `targetClass`.
+ *
+ * Pending scalar receives for the key are materialized first. If a whole-batch
+ * assembly plan can be built it is emitted; otherwise the value is assembled
+ * from per-lane projected fragments. Returns failure when the pending receives
+ * or the selected assembly path fail.
+ */
+FailureOr materializeWholeBatchInput(
+    MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) {
+  if (failed(materializePendingScalarReceivesForWholeBatchInput(state, targetClass, key, loc)))
+    return failure();
+
+  FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType);
+  if (succeeded(plan))
+    return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc);
+  // A failed plan is not an error by itself: fall back to fragment-based assembly.
+  return materializeProjectedWholeBatchInputFromFragments(state, targetClass, key, resultType, loc);
+}
+/*
+ * Resolves `input` of `consumerInstance` to a value usable inside `targetClass`.
+ *
+ * Constant-like inputs are returned unchanged. For inputs with a producer key
+ * the following sources are tried in order: an indexed batch run receive (only
+ * when `indexing.runSlotIndex` is set), an already available value, a pending
+ * scalar receive, the indexed batch run slot containing the producer, and
+ * whole-batch materialization; if none applies an error is emitted. Inputs
+ * without a producer key are appended as a normal input unless they are
+ * tensors defined in another materialized class. Every resolved producer value
+ * is rejected with a diagnostic when it is such a non-local tensor.
+ */
+FailureOr resolveInputValue(MaterializerState& state,
+                            MaterializedClass& targetClass,
+                            Value input,
+                            const ComputeInstance& consumerInstance,
+                            CloneIndexingContext indexing) {
+  auto
rejectNonLocalResolvedValue = [&](Value resolved) -> FailureOr { // passes local values through; otherwise emits a diagnostic and fails
+    if (!isTensorValueDefinedInDifferentMaterializedClass(resolved, targetClass))
+      return resolved;
+
+    std::optional producer = getInputRequestProducerKey(input, consumerInstance);
+    emitNonLocalMaterializedClassValueDiagnostic(consumerInstance.op,
+                                                 targetClass,
+                                                 "input resolution tried to reuse a tensor from another materialized class",
+                                                 resolved,
+                                                 producer);
+    return failure();
+  };
+
+  if (isConstantLike(input))
+    return input;
+
+  if (std::optional producer = getInputRequestProducerKey(input, consumerInstance)) {
+    if (indexing.runSlotIndex) { // with a run slot index, an indexed run receive takes priority over every other source
+      if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) {
+        FailureOr received = materializeIndexedBatchRunReceive(
+            state, targetClass, *indexedRun, *indexing.runSlotIndex, consumerInstance.op->getLoc());
+        if (failed(received))
+          return failure();
+        return rejectNonLocalResolvedValue(*received);
+      }
+    }
+
+    if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id))
+      return rejectNonLocalResolvedValue(*value);
+
+    if (auto pendingReceive = lookupPendingScalarReceiveIndex(state, *producer, targetClass.id)) {
+      FailureOr received =
+          materializePendingScalarReceive(state, targetClass, *pendingReceive, consumerInstance.op->getLoc());
+      if (failed(received))
+        return failure();
+      return rejectNonLocalResolvedValue(*received);
+    }
+    // No run slot index was usable above: locate the run slot that lists this producer and receive that slot's message slice.
+    if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) {
+      size_t laneCount = targetClass.cpus.size();
+      for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) {
+        if (!llvm::is_contained(slot.keys, *producer))
+          continue;
+
+        MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); // `laneCount` consecutive messages per slot
+        Value received =
+            appendReceive(state, targetClass, indexedRun->fragmentType, messages, consumerInstance.op->getLoc());
+        for (ProducerKey slotKey : slot.keys)
+
state.availableValues.record(slotKey, targetClass.id, received); // the single received value is recorded for every key of the slot
+        return rejectNonLocalResolvedValue(received);
+      }
+    }
+
+    if (isWholeBatchProducerKey(*producer)) {
+      FailureOr wholeBatch =
+          materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc());
+      if (failed(wholeBatch)) // NOTE(review): the same condition is tested again right below; the two ifs could be one braced block
+        consumerInstance.op->emitError("failed to materialize whole-batch input")
+            << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart
+            << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex;
+      if (failed(wholeBatch))
+        return failure();
+      return rejectNonLocalResolvedValue(*wholeBatch);
+    }
+    // Every known source was tried: report the unresolved producer.
+    consumerInstance.op->emitError("failed to resolve producer value")
+        << " from op '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart
+        << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex;
+    return failure();
+  }
+  // No producer key: the value is forwarded as a normal input, which is refused for tensors owned by another materialized class.
+  if (isTensorValueDefinedInDifferentMaterializedClass(input, targetClass)) {
+    emitNonLocalMaterializedClassValueDiagnostic(
+        consumerInstance.op,
+        targetClass,
+        "input resolution tried to append a tensor from another materialized class as a normal input",
+        input);
+    return failure();
+  }
+
+  return appendInput(state, targetClass, input);
+}
+/*
+ * Returns true when input `inputIndex` of `batch` has a projected input slice
+ * match and a replacement for the matched extract op is registered for
+ * `classId` in state.projectedExtractReplacements.
+ */
+bool hasProjectedInputReplacement(MaterializerState& state,
+                                  SpatComputeBatch batch,
+                                  unsigned inputIndex,
+                                  ClassId classId) {
+  std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex);
+  if (!match)
+    return false;
+
+  auto replacementIt = state.projectedExtractReplacements.find(match->extract.getOperation());
+  if (replacementIt == state.projectedExtractReplacements.end())
+    return false;
+
+  return replacementIt->second.find(classId) != replacementIt->second.end();
+}
+/*
+ * Maps every weight block argument of the instance's source op to the value
+ * returned by appendWeight for `targetClass`. The source op is handled either
+ * through the first dyn_cast branch or, failing that, through an unconditional
+ * cast to the batch form.
+ */
+void mapWeights(MaterializerState& state,
+                MaterializedClass& targetClass,
+                const ComputeInstance& instance,
+                IRMapping& mapper) {
+  Operation*
op = instance.op;
+  if (auto compute = dyn_cast(op)) {
+    for (auto [index, weight] : llvm::enumerate(compute.getWeights())) {
+      auto weightArg = compute.getWeightArgument(index);
+      assert(weightArg && "expected compute weight block argument"); // NOTE(review): assert-only guard; mapInputs reports the equivalent case with emitOpError
+      mapper.map(*weightArg, appendWeight(state, targetClass, weight));
+    }
+    return;
+  }
+
+  auto batch = cast(op);
+  for (auto [index, weight] : llvm::enumerate(batch.getWeights())) {
+    auto weightArg = batch.getWeightArgument(index);
+    assert(weightArg && "expected compute_batch weight block argument");
+    mapper.map(*weightArg, appendWeight(state, targetClass, weight));
+  }
+}
+/*
+ * Resolves every input of the instance's source op and maps the matching input
+ * block argument to the resolved value, localized for `targetClass`.
+ *
+ * For the batch form, inputs that have a projected input replacement for this
+ * class are skipped. Returns failure, after emitting a diagnostic, when an
+ * input cannot be resolved, its block argument is missing, or the resolved
+ * value cannot be localized.
+ */
+LogicalResult mapInputs(MaterializerState& state,
+                        MaterializedClass& targetClass,
+                        const ComputeInstance& instance,
+                        IRMapping& mapper,
+                        CloneIndexingContext indexing) {
+  auto mapResolvedInput = [&](Value resolved) -> FailureOr { // localizes a resolved value for use inside targetClass
+    return materializeTensorValueForMaterializedClassUse(
+        state,
+        targetClass,
+        resolved,
+        targetClass.op,
+        "input mapping tried to reuse a tensor from another materialized class");
+  };
+
+  Operation* op = instance.op;
+  if (auto compute = dyn_cast(op)) {
+    for (auto [index, input] : llvm::enumerate(compute.getInputs())) {
+      FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing);
+      if (failed(mapped)) {
+        std::optional producer = getInputRequestProducerKey(input, instance);
+        auto diagnostic = compute.emitOpError("failed to resolve materialized compute input") << " #" << index;
+        if (producer) { // producer details are appended only when the input has a producer key
+          diagnostic << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart
+                     << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex;
+        }
+        return failure();
+      }
+      auto inputArg = compute.getInputArgument(index);
+      if (!inputArg)
+        return compute.emitOpError("expected compute input block argument while materializing inputs");
+      FailureOr remapped = mapResolvedInput(*mapped);
+      if (failed(remapped)) {
+
emitNonLocalMaterializedClassValueDiagnostic(compute,
+                                                     targetClass,
+                                                     "mapInputs tried to append a tensor from another materialized class",
+                                                     *mapped,
+                                                     getInputRequestProducerKey(input, instance));
+        return failure();
+      }
+      mapper.map(*inputArg, *remapped);
+    }
+    return success();
+  }
+  // Batch form: same steps as above, except that projected inputs with a replacement for this class are not mapped here.
+  auto batch = cast(op);
+  for (auto [index, input] : llvm::enumerate(batch.getInputs())) {
+    if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id))
+      continue;
+
+    FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing);
+    if (failed(mapped))
+      return batch.emitOpError("failed to resolve materialized compute_batch input");
+    auto inputArg = batch.getInputArgument(index);
+    if (!inputArg)
+      return batch.emitOpError("expected compute_batch input block argument while materializing inputs");
+    FailureOr remapped = mapResolvedInput(*mapped);
+    if (failed(remapped)) {
+      emitNonLocalMaterializedClassValueDiagnostic(batch,
+                                                   targetClass,
+                                                   "mapInputs tried to append a tensor from another materialized class",
+                                                   *mapped,
+                                                   getInputRequestProducerKey(input, instance));
+      return failure();
+    }
+    mapper.map(*inputArg, *remapped);
+  }
+  return success();
+}
+/*
+ * Returns one value per result of `batch`: the mapped source of the insert in
+ * the body terminator's region whose destination is the corresponding output
+ * block argument of the batch body. Entries stay null when no such insert is
+ * found, when the terminator is not of the expected kind, or when the first
+ * output argument is unavailable.
+ */
+SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMapping& mapper) {
+  SmallVector outputs(batch.getNumResults(), Value {});
+  auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator());
+  if (!inParallel)
+    return outputs;
+
+  for (Operation& op : inParallel.getRegion().front()) {
+    auto insert = dyn_cast(&op);
+    if (!insert)
+      continue;
+
+    auto outputArg = dyn_cast(insert.getDest());
+    if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) // only destinations that are block arguments of the batch body count
+      continue;
+    // NOTE(review): loop-invariant lookup; collectBatchOutputFragmentTypes performs it once before its loop.
+    auto firstOutputArg = batch.getOutputArgument(0);
+    if (!firstOutputArg)
+      return outputs;
+    unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); // result index = offset from the first output argument
+    if (resultIndex >= outputs.size())
+      continue;
+    outputs[resultIndex] =
mapper.lookupOrDefault(insert.getSource()); // falls back to the original source when the mapper has no entry
+  }
+
+  return outputs;
+}
+/*
+ * Returns one type per result of `batch`: the type of the source of the insert
+ * in the body terminator's region that targets the corresponding output block
+ * argument. Entries stay null when no such insert exists.
+ */
+SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) {
+  SmallVector types(batch.getNumResults(), Type {});
+  auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator());
+  if (!inParallel)
+    return types;
+
+  auto firstOutputArg = batch.getOutputArgument(0);
+  if (!firstOutputArg)
+    return types;
+
+  for (Operation& op : inParallel.getRegion().front()) {
+    auto insert = dyn_cast(&op);
+    if (!insert)
+      continue;
+
+    auto outputArg = dyn_cast(insert.getDest());
+    if (!outputArg || outputArg.getOwner() != &batch.getBody().front())
+      continue;
+
+    unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
+    if (resultIndex >= types.size())
+      continue;
+
+    types[resultIndex] = insert.getSource().getType();
+  }
+
+  return types;
+}
+/*
+ * Memoized collectBatchOutputFragmentTypes, keyed by the batch operation in
+ * state.batchOutputFragmentTypesCache.
+ * NOTE(review): returns a reference into the cache; whether it survives later
+ * insertions depends on the container type, which is not visible here.
+ */
+SmallVector& getBatchOutputFragmentTypesCached(MaterializerState& state, SpatComputeBatch batch) {
+  auto [it, inserted] = state.batchOutputFragmentTypesCache.try_emplace(batch.getOperation(), SmallVector {});
+  if (inserted) // compute only on first request
+    it->second = collectBatchOutputFragmentTypes(batch);
+  return it->second;
+}
+/*
+ * Memoized getComputeInstanceOutputValues, keyed by the instance in
+ * state.computeInstanceOutputsCache. The returned ArrayRef views the cached
+ * vector.
+ */
+ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance) {
+  auto [it, inserted] = state.computeInstanceOutputsCache.try_emplace(instance, SmallVector {});
+  if (inserted)
+    it->second = getComputeInstanceOutputValues(instance);
+  return it->second;
+}
+/*
+ * Looks up the replacement registered for `extract` and the id of
+ * `targetClass` in state.projectedExtractReplacements. Returns std::nullopt
+ * when either the op or the class has no entry.
+ */
+std::optional lookupProjectedExtractReplacement(MaterializerState& state,
+                                                MaterializedClass& targetClass,
+                                                tensor::ExtractSliceOp extract) {
+  auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation());
+  if (replacementIt == state.projectedExtractReplacements.end())
+    return std::nullopt;
+
+  auto classIt = replacementIt->second.find(targetClass.id);
+  if (classIt == replacementIt->second.end())
+    return std::nullopt;
+
+  return classIt->second;
+}
+/*
+ * Returns true when `sourceOp` contains an extract slice that has a
+ * replacement for `targetClass` whose payload fragment count differs from its
+ * fragments-per-logical-slot count.
+ */
+bool
requiresConstantProjectionSlotIndex(MaterializerState& state,
+                                    MaterializedClass& targetClass,
+                                    Operation* sourceOp) {
+  bool requiresConstantIndex = false;
+  sourceOp->walk([&](tensor::ExtractSliceOp extract) {
+    if (requiresConstantIndex)
+      return WalkResult::interrupt();
+
+    std::optional replacement =
+        lookupProjectedExtractReplacement(state, targetClass, extract);
+    if (!replacement)
+      return WalkResult::advance();
+
+    if (replacement->layout.payloadFragmentCount != replacement->layout.fragmentsPerLogicalSlot) {
+      requiresConstantIndex = true; // one mismatching replacement is enough; stop walking
+      return WalkResult::interrupt();
+    }
+
+    return WalkResult::advance();
+  });
+  return requiresConstantIndex;
+}
+/*
+ * Walks `originalOp` and `clonedOp` in lockstep and applies projected extract
+ * replacements inside the clone.
+ *
+ * When the original op is an extract slice with a replacement for
+ * `targetClass`, the cloned extract is replaced by the value produced by
+ * materializeProjectedExtractReplacement and erased. Otherwise the function
+ * recurses into the nested regions, blocks and operations, and emits an error
+ * when the original and the clone are not structurally identical.
+ */
+LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state,
+                                                          MaterializedClass& targetClass,
+                                                          Operation& originalOp,
+                                                          Operation& clonedOp,
+                                                          CloneIndexingContext indexing,
+                                                          IRMapping& mapper) {
+  if (auto originalExtract = dyn_cast(&originalOp)) {
+    if (std::optional replacement =
+            lookupProjectedExtractReplacement(state, targetClass, originalExtract)) {
+      auto clonedExtract = dyn_cast(&clonedOp);
+      if (!clonedExtract)
+        return targetClass.op->emitError("projected replacement lost extract structure during cloning");
+
+      state.rewriter.setInsertionPoint(clonedExtract); // replacement ops are created right before the cloned extract
+      FailureOr projected = materializeProjectedExtractReplacement(
+          state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex, &mapper);
+      if (failed(projected))
+        return failure();
+
+      clonedExtract.getResult().replaceAllUsesWith(*projected);
+      state.rewriter.eraseOp(clonedExtract);
+      return success();
+    }
+  }
+
+  if (originalOp.getNumRegions() != clonedOp.getNumRegions())
+    return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned regions");
+
+  for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) {
+    if (std::distance(originalRegion.begin(), originalRegion.end())
+        != std::distance(clonedRegion.begin(),
clonedRegion.end()))
+      return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned blocks");
+
+    for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) {
+      auto originalIt = originalBlock.begin();
+      auto clonedIt = clonedBlock.begin();
+      while (originalIt != originalBlock.end() && clonedIt != clonedBlock.end()) {
+        Operation& originalNestedOp = *originalIt++;
+        Operation* currentClonedOp = &*clonedIt++; // advance first: the recursive call may erase the op just visited
+        if (failed(applyProjectedExtractReplacementsInClonedOp(
+                state, targetClass, originalNestedOp, *currentClonedOp, indexing, mapper)))
+          return failure();
+      }
+      if (originalIt != originalBlock.end() || clonedIt != clonedBlock.end()) // one block ran out of ops before the other
+        return targetClass.op->emitError("projected replacement traversal found mismatched cloned operations");
+    }
+  }
+
+  return success();
+}
+/*
+ * Records in `mapper` the correspondence between the block arguments nested in
+ * `originalOp` and those nested in `clonedOp`, recursing through all regions,
+ * blocks and operations. Existing mapper entries are kept. Emits an error on
+ * `clonedOp` when the two operations differ in region, block, block argument
+ * or operation count.
+ */
+LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& clonedOp, IRMapping& mapper) {
+  if (originalOp.getNumRegions() != clonedOp.getNumRegions())
+    return clonedOp.emitError("cloned operation has a different number of regions than the source operation");
+
+  for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) {
+    if (std::distance(originalRegion.begin(), originalRegion.end())
+        != std::distance(clonedRegion.begin(), clonedRegion.end()))
+      return clonedOp.emitError("cloned operation has a different number of blocks than the source operation");
+
+    for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) {
+      if (originalBlock.getNumArguments() != clonedBlock.getNumArguments())
+        return clonedOp.emitError("cloned operation block has a different number of arguments than the source block");
+
+      for (auto [originalArg, clonedArg] : llvm::zip(originalBlock.getArguments(), clonedBlock.getArguments()))
+        if (!mapper.contains(originalArg)) // never overwrite an existing mapping
+          mapper.map(originalArg, clonedArg);
+
+      if
(std::distance(originalBlock.begin(), originalBlock.end()) != std::distance(clonedBlock.begin(), clonedBlock.end()))
+        return clonedOp.emitError("cloned operation block has a different number of operations than the source block");
+      // Operation counts were checked to match above, so only the original iterator is tested.
+      auto originalIt = originalBlock.begin();
+      auto clonedIt = clonedBlock.begin();
+      while (originalIt != originalBlock.end()) {
+        Operation& originalNestedOp = *originalIt++;
+        Operation& clonedNestedOp = *clonedIt++;
+        if (failed(mapClonedRegionBlockArguments(originalNestedOp, clonedNestedOp, mapper)))
+          return failure();
+      }
+    }
+  }
+
+  return success();
+}
+/*
+ * Clones every non-terminator operation of the instance's template block at
+ * the rewriter's current insertion point, updating `mapper` as it goes.
+ *
+ * Top-level extract slices with a projected replacement for `targetClass` are
+ * not cloned; their result is mapped to the materialized replacement instead.
+ * For every other op, unmapped operands are localized for the class first; the
+ * op is then cloned, its nested block arguments are mapped, captures inside
+ * the clone are localized, projected replacements are applied inside its
+ * regions, and its results are mapped to the clone's results.
+ */
+LogicalResult cloneComputeTemplateBody(MaterializerState& state,
+                                       MaterializedClass& targetClass,
+                                       const ComputeInstance& instance,
+                                       IRMapping& mapper,
+                                       CloneIndexingContext indexing) {
+  Block& sourceBlock = getComputeInstanceTemplateBlock(instance);
+  for (Operation& op : sourceBlock.without_terminator()) {
+    if (auto extract = dyn_cast(&op)) {
+      if (std::optional replacement =
+              lookupProjectedExtractReplacement(state, targetClass, extract)) {
+        FailureOr projected = materializeProjectedExtractReplacement(
+            state, targetClass, extract, *replacement, indexing.projectionSlotIndex, &mapper);
+        if (failed(projected))
+          return failure();
+
+        mapper.map(extract.getResult(), *projected); // later uses of the extract see the replacement; the extract itself is not cloned
+        continue;
+      }
+    }
+
+    for (Value operand : op.getOperands()) {
+      if (mapper.contains(operand))
+        continue;
+
+      FailureOr localized = localizeMaterializedClassOperand(
+          state,
+          targetClass,
+          operand,
+          &op,
+          "cloneComputeTemplateBody tried to reuse a tensor from another materialized class",
+          "cloneComputeTemplateBody produced an unsupported external non-tensor operand",
+          &mapper);
+      if (failed(localized))
+        return failure();
+      if (*localized != operand) // only record a mapping when localization produced a different value
+        mapper.map(operand, *localized);
+    }
+
+    Operation* cloned = state.rewriter.clone(op, mapper);
+    if (failed(mapClonedRegionBlockArguments(op, *cloned, mapper)))
+      return failure();
+    if (failed(localizeCapturesInClonedOp(state,
targetClass, *cloned, &mapper)))
+      return failure();
+    if (op.getNumRegions() != 0
+        && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing, mapper)))
+      return failure();
+    for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults()))
+      mapper.map(oldResult, newResult);
+  }
+
+  return success();
+}
+/*
+ * Produces the value that replaces `extract` inside `targetClass` from the
+ * payload of a projected extract replacement.
+ *
+ * The payload is localized for the class first. A payload with a single
+ * fragment is returned as is. Otherwise a fragment index is computed: the
+ * induction variables of the loops surrounding `extract` are linearized (when
+ * the layout records loop trip counts), and, when the payload holds more
+ * fragments than one logical slot, `projectionSlotIndex` multiplied by the
+ * fragments-per-slot count is added. The index is scaled by the dimension-0
+ * size of the fragment type and passed to createDim0ExtractSliceInClass.
+ *
+ * Returns failure when the layout does not verify, the payload cannot be
+ * localized or is smaller than one logical slot, the loop structure does not
+ * match the layout, loop metadata is missing, or a required
+ * `projectionSlotIndex` is absent.
+ */
+FailureOr materializeProjectedExtractReplacement(MaterializerState& state,
+                                                 MaterializedClass& targetClass,
+                                                 tensor::ExtractSliceOp extract,
+                                                 const ProjectedExtractReplacement& replacement,
+                                                 std::optional projectionSlotIndex,
+                                                 IRMapping* mapper) {
+  if (failed(verifyProjectedFragmentLayout(targetClass.op, replacement.layout)))
+    return failure();
+
+  FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse(
+      state,
+      targetClass,
+      replacement.payload,
+      targetClass.op,
+      "projected extract replacement tried to reuse a tensor from another materialized class",
+      std::nullopt,
+      mapper);
+  if (failed(localizedPayload))
+    return failure();
+  Value payload = *localizedPayload;
+
+  if (replacement.layout.payloadFragmentCount == 1) // a single fragment needs no slicing
+    return payload;
+
+  if (replacement.layout.payloadFragmentCount < replacement.layout.fragmentsPerLogicalSlot)
+    return targetClass.op->emitError("projected replacement payload is smaller than one logical slot");
+
+  Value intraSlotFragmentIndex = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0);
+  const auto linearizeProjectedLoopIndices = [&]() -> FailureOr {
+    if (replacement.layout.loopTripCounts.empty())
+      return intraSlotFragmentIndex;
+    // Collect the loops enclosing `extract`, innermost first, stopping at the materialized class op.
+    SmallVector surroundingLoops;
+    for (Operation* current = extract->getParentOp(); current; current = current->getParentOp()) {
+      if (auto loop = dyn_cast(current))
+        surroundingLoops.push_back(loop);
+      if (current == targetClass.op)
+        break;
+    }
+    std::reverse(surroundingLoops.begin(), surroundingLoops.end()); // outermost loop first
+
+    if (surroundingLoops.size() != replacement.layout.loopTripCounts.size())
+      return targetClass.op->emitError("projected replacement loop structure does not match the collected descriptor");
+    // Accumulate index = index * tripCount + (iv - lowerBound) / step over the loops, outermost first.
+    Value linearizedIndex = intraSlotFragmentIndex;
+    for (auto [index, loop] : llvm::enumerate(surroundingLoops)) {
+      FailureOr localizedIv =
+          rematerializeIndexValueInClass(state, targetClass, loop.getInductionVar(), extract.getLoc(), mapper);
+      if (failed(localizedIv))
+        return failure();
+      Value iv = *localizedIv;
+      Value lowerBound =
+          getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]);
+      Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]);
+      Value tripCount =
+          getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]);
+
+      Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult();
+      if (replacement.layout.loopSteps[index] != 1) // the division is only emitted for non-unit steps
+        normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult();
+      linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult();
+      linearizedIndex =
+          arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult();
+    }
+    return linearizedIndex;
+  };
+
+  FailureOr linearizedIndex = linearizeProjectedLoopIndices();
+  if (failed(linearizedIndex))
+    return failure();
+  intraSlotFragmentIndex = *linearizedIndex;
+
+  const auto computeProjectedPayloadFragmentIndex = [&]() -> FailureOr {
+    if (replacement.layout.payloadFragmentCount == replacement.layout.fragmentsPerLogicalSlot) { // payload is exactly one logical slot: no slot offset
+      if (replacement.layout.loopTripCounts.empty() && replacement.layout.fragmentsPerLogicalSlot != 1)
+        return targetClass.op->emitError("projected replacement is missing loop metadata for packed logical slot");
+      return intraSlotFragmentIndex;
+    }
+
+    if (!projectionSlotIndex)
+      return targetClass.op->emitError("packed projected extract replacement requires a fragment slot index");
+
+    FailureOr localProjectionSlotIndex =
+        rematerializeIndexValueInClass(state, targetClass, *projectionSlotIndex, extract.getLoc(), mapper);
+    if (failed(localProjectionSlotIndex))
+      return failure();
+    // fragment index = slot index * fragmentsPerLogicalSlot + index inside the slot
+    Value fragmentsPerLogicalSlot =
+        getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot);
+    Value base =
+        arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot)
+            .getResult();
+    return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult();
+  };
+
+  FailureOr packedFragmentIndex = computeProjectedPayloadFragmentIndex();
+  if (failed(packedFragmentIndex))
+    return failure();
+
+  FailureOr packedOffset = scaleIndexByDim0SizeInClass(
+      state, targetClass, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc());
+  if (failed(packedOffset))
+    return failure();
+  return createDim0ExtractSliceInClass(
+      state, targetClass, extract.getLoc(), payload, *packedOffset, replacement.layout.fragmentType.getDimSize(0));
+}
+/*
+ * Emits a channel receive for one slot of an indexed batch run inside a batch
+ * target class.
+ *
+ * The channel, source core and target core ids are created from
+ * `run.messages`, indexed by a flat index derived from `runSlotIndex`, using
+ * the number of CPUs of the class as the preferred period. The receive is
+ * created right before the class body terminator. Fails when the target class
+ * is not a batch class or `run.messages` does not verify.
+ */
+FailureOr materializeIndexedBatchRunReceive(MaterializerState& state,
+                                            MaterializedClass& targetClass,
+                                            IndexedBatchRunValue& run,
+                                            Value runSlotIndex,
+                                            Location loc) {
+  if (!targetClass.isBatch)
+    return targetClass.op->emitError("indexed batch run receive requires a batch target class");
+  if (failed(run.messages.verify(targetClass.op)))
+    return failure();
+
+  Value flatIndex = createBatchRunFlatIndex(state, targetClass, runSlotIndex, loc);
+  std::optional preferredPeriod = static_cast(targetClass.cpus.size());
+  Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod);
+  Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod);
+  Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op,
run.messages, flatIndex, loc, preferredPeriod);
+
+  state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
+  return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId)
+      .getOutput();
+}
+/*
+ * Walks `root` and all operations nested in it and replaces every operand that
+ * is not legal inside the body of `targetClass` with the value returned by
+ * localizeMaterializedClassOperand. The rewriter insertion point is set to the
+ * using operation for each localization and restored afterwards.
+ *
+ * `tensorContext` and `genericContext` are forwarded to the localization
+ * helper. On the first operand that cannot be localized a detailed debug
+ * diagnostic is emitted on the class op and failure is returned.
+ */
+LogicalResult localizeCapturesInOperationTree(MaterializerState& state,
+                                              MaterializedClass& targetClass,
+                                              Operation& root,
+                                              StringRef tensorContext,
+                                              StringRef genericContext,
+                                              IRMapping* mapper = nullptr) {
+  WalkResult walkResult = root.walk([&](Operation* nestedOp) -> WalkResult {
+    for (OpOperand& operand : nestedOp->getOpOperands()) {
+      Value current = operand.get();
+      if (isValueLegalInMaterializedClassBody(current, targetClass))
+        continue;
+
+      OpBuilder::InsertionGuard guard(state.rewriter); // restores the insertion point at the end of this iteration
+      state.rewriter.setInsertionPoint(nestedOp);
+      FailureOr localized =
+          localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper);
+      if (failed(localized)) {
+        InFlightDiagnostic diagnostic = targetClass.op->emitError(
+            "RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand");
+        diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName()
+                   << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType()
+                   << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp))
+                   << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass)
+                   << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\"";
+        diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation";
+        attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR");
+        attachMaterializerOperandListNote(diagnostic, nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands");
+        attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain");
+        attachMaterializerValueOriginNote(diagnostic, current, "offending operand");
+        attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op");
+        attachMaterializedClassBodySummary(diagnostic, targetClass);
+        return WalkResult::interrupt();
+      }
+      operand.set(*localized); // rewrite the use in place
+    }
+    return WalkResult::advance();
+  });
+
+  return walkResult.wasInterrupted() ? failure() : success();
+}
+/*
+ * Localizes captures inside an op cloned by cloneComputeTemplateBody; forwards
+ * to localizeCapturesInOperationTree with that function's diagnostic context
+ * strings.
+ */
+LogicalResult localizeCapturesInClonedOp(MaterializerState& state,
+                                         MaterializedClass& targetClass,
+                                         Operation& clonedOp,
+                                         IRMapping* mapper) {
+  return localizeCapturesInOperationTree(
+      state,
+      targetClass,
+      clonedOp,
+      "cloneComputeTemplateBody tried to reuse a tensor from another materialized class",
+      "cloneComputeTemplateBody produced an unsupported external non-tensor operand",
+      mapper);
+}
+/*
+ * Final pass over the whole body of `targetClass`: every operand that is not
+ * legal inside the class body is replaced by its localized value. The
+ * operations are snapshotted before any rewriting, and operations that are no
+ * longer attached to a block are skipped. Emits a debug diagnostic and returns
+ * failure on the first operand that cannot be localized.
+ * NOTE(review): the snapshot holds raw pointers; this assumes operations
+ * removed during localization are unlinked rather than destroyed - confirm.
+ */
+LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass) {
+  SmallVector bodyOps;
+  for (Operation& op : *targetClass.body)
+    op.walk([&](Operation* nestedOp) { bodyOps.push_back(nestedOp); });
+
+  for (Operation* nestedOp : bodyOps) {
+    if (nestedOp->getBlock() == nullptr)
+      continue;
+    for (OpOperand& operand : nestedOp->getOpOperands()) {
+      Value current = operand.get();
+      if (isValueLegalInMaterializedClassBody(current, targetClass))
+        continue;
+
+      OpBuilder::InsertionGuard guard(state.rewriter);
+      state.rewriter.setInsertionPoint(nestedOp);
+      FailureOr localized = localizeMaterializedClassOperand(
+          state,
+          targetClass,
+          current,
+          nestedOp,
+          "final scheduled body capture localization tried to reuse a tensor from another materialized class",
+          "final scheduled body capture localization found an unsupported external non-tensor operand");
+      if (failed(localized)) {
+        InFlightDiagnostic diagnostic = targetClass.op->emitError(
+            "RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand");
+
diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() + << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() + << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) + << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; + diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; + attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); + } + operand.set(*localized); + } + } + + return success(); +} + +FailureOr> cloneInstanceBody(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef peers, + CloneIndexingContext indexing) { + assert(!peers.empty() && "expected at least one peer instance"); + const ComputeInstance& instance = peers.front(); + Operation* sourceOp = instance.op; + Location loc = sourceOp->getLoc(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + IRMapping mapper; + if (auto batch = dyn_cast(sourceOp)) { + for (const ComputeInstance& peer : peers) { + if (peer.op != sourceOp) { + sourceOp->emitError("equivalence class slot contains different source compute_batch operations"); + return failure(); + } + } + auto laneArg = batch.getLaneArgument(); + if (!laneArg) { + sourceOp->emitError("expected source compute_batch lane block argument"); + return failure(); + } + mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc)); + } + + OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); + + mapWeights(state, targetClass, instance, mapper); + if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) + return failure(); + + state.rewriter.restoreInsertionPoint(cloneInsertionPoint); + if 
(failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) + return failure(); + + if (auto compute = dyn_cast(sourceOp)) { + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); + auto yield = dyn_cast_or_null(sourceBlock.getTerminator()); + if (!yield) { + compute.emitOpError("expected spat.yield terminator while materializing compute"); + return failure(); + } + + SmallVector outputs; + outputs.reserve(yield.getNumOperands()); + for (Value yielded : yield.getOutputs()) + outputs.push_back(mapper.lookupOrDefault(yielded)); + return outputs; + } + + auto batch = cast(sourceOp); + if (batch.getNumResults() == 0) + return SmallVector {}; + + SmallVector outputs = collectMappedBatchOutputs(batch, mapper); + for (Value output : outputs) + if (!output) { + batch.emitOpError("failed to recover yielded per-lane value for compute_batch result"); + return failure(); + } + return outputs; +} + +bool sameDestinationClasses(ArrayRef lhs, ArrayRef rhs) { + if (lhs.size() != rhs.size()) + return false; + for (auto [lhsClass, rhsClass] : llvm::zip(lhs, rhs)) + if (lhsClass != rhsClass) + return false; + return true; +} + +SmallVector +collectDestinationClassesForRun(MaterializerState& state, ArrayRef run, size_t resultIndex) { + SmallVector destinations; + + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + ProducerKey key {peer, resultIndex}; + for (ClassId destinationClass : getDestinationClasses(state, key)) + if (!llvm::is_contained(destinations, destinationClass)) + destinations.push_back(destinationClass); + } + } + + llvm::sort(destinations); + return destinations; +} + +SmallVector groupBatchRunOutputsByDestination(MaterializerState& state, + ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + + SmallVector groups; + ArrayRef outputs = 
getComputeInstanceOutputValuesCached(state, run.front().peers.front()); + + for (auto [resultIndex, output] : llvm::enumerate(outputs)) { + SmallVector destinations = collectDestinationClassesForRun(state, run, resultIndex); + + auto existingGroup = llvm::find_if(groups, [&](const OutputDestinationGroup& group) { + return sameDestinationClasses(group.destinationClasses, destinations); + }); + + if (existingGroup != groups.end()) { + existingGroup->resultIndices.push_back(resultIndex); + continue; + } + + OutputDestinationGroup group; + group.resultIndices.push_back(resultIndex); + group.destinationClasses = std::move(destinations); + groups.push_back(std::move(group)); + } + + return groups; +} + +FailureOr getPackedRunTensorType(Type elementType, size_t runSize) { + auto tensorType = dyn_cast(elementType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return failure(); + + SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(runSize); + return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +LogicalResult registerDeferredLocalPackedRunValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Type fragmentType, + Location loc) { + if (keys.empty()) + return success(); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return materializedClass.op->emitError("deferred local packed run expects static ranked fragment type"); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("deferred local packed run expects one producer result"); + + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("deferred local packed run expects one 
lane per fragment"); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredLocalCompute; + packedRun.fragmentType = rankedFragmentType; + + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult registerPackedRunValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Value packed, + Type fragmentType, + Location loc) { + if (keys.empty()) + return success(); + + FailureOr expectedPackedType = getPackedRunTensorType(fragmentType, keys.size()); + if (failed(expectedPackedType)) + return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); + + if (packed.getType() != *expectedPackedType) + return materializedClass.op->emitError("packed run value has unexpected tensor type"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("packed run registration expects one producer result"); + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("packed run registration expects one lane per packed fragment"); + } + + if (std::optional contiguousKey = getContiguousProducerRangeForKeys(keys)) { + state.availableValues.record(*contiguousKey, materializedClass.id, 
packed); + return success(); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.packed = packed; + packedRun.kind = PackedScalarRunKind::Materialized; + packedRun.fragmentType = rankedFragmentType; + + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult emitPackedRunFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef destinationClasses, + ArrayRef keys, + Value packed, + Type fragmentType, + Location loc) { + assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); + + auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, packed); + if (failed(fanoutPlan)) + return failure(); + if (failed(emitScalarSourceFanoutSends(state, sourceClass, packed, *fanoutPlan, loc))) + return failure(); + + for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { + MaterializedClass& targetClass = state.classes[plan.targetClass]; + + Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); + + if (plan.projectedExtractOp) { + state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = + ProjectedExtractReplacement {received, plan.projectedLayout}; + continue; + } + + if (failed(registerPackedRunValue(state, targetClass, keys, received, fragmentType, loc))) + return failure(); + } + + return success(); +} + +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value laneValue, + ArrayRef resultIndices, + CloneIndexingContext indexing) { + auto batch = dyn_cast(instance.op); + if (!batch) + return failure(); + + 
IRMapping mapper; + auto sourceLaneArg = batch.getLaneArgument(); + if (!sourceLaneArg) + return batch.emitOpError("expected source compute_batch lane block argument"); + + mapper.map(*sourceLaneArg, laneValue); + + OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); + + mapWeights(state, targetClass, instance, mapper); + if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) + return failure(); + + state.rewriter.restoreInsertionPoint(cloneInsertionPoint); + if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) + return failure(); + + SmallVector allOutputs = collectMappedBatchOutputs(batch, mapper); + if (allOutputs.empty() && !resultIndices.empty()) + return batch.emitOpError("failed to recover source compute_batch outputs"); + + SmallVector selectedOutputs; + selectedOutputs.reserve(resultIndices.size()); + for (size_t resultIndex : resultIndices) { + if (resultIndex >= allOutputs.size() || !allOutputs[resultIndex]) + return batch.emitOpError("failed to recover selected compute_batch output"); + selectedOutputs.push_back(allOutputs[resultIndex]); + } + + return selectedOutputs; +} + +FailureOr> materializeBatchOutputGroupLoop(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const OutputDestinationGroup& group) { + assert(!run.empty() && "expected non-empty batch run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + + Operation* sourceOp = run.front().peers.front().op; + Location loc = sourceOp->getLoc(); + + if (run.size() == 1) { + if (run.front().peers.size() != 1) + return sourceOp->emitError("scalar batch output loop expects exactly one peer in singleton slot"); + + const ComputeInstance& item = run.front().peers.front(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart); + return 
cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {}); + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + auto sourceBatch = cast(sourceOp); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + SmallVector initValues; + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); + + Type fragmentType = fragmentTypes[resultIndex]; + FailureOr packedType = getPackedRunTensorType(fragmentType, run.size()); + if (failed(packedType)) + return sourceBatch.emitOpError("cannot materialize packed batch run for non-static ranked output"); + + initValues.push_back( + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); + } + + SmallVector logicalLanes; + logicalLanes.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + if (slot.peers.size() != 1) + return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); + + const ComputeInstance& item = slot.peers.front(); + if (item.op != sourceOp) + return sourceOp->emitError("materialization run contains different source operations"); + + logicalLanes.push_back(item.laneStart); + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange(initValues), + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value 
sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + run.front().peers.front(), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); + if (failed(produced)) + return failure(); + + yielded.reserve(produced->size()); + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = cast(output.getType()); + Value acc = iterArgs[outputIndex]; + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, fragmentType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, output, acc, *firstOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); + } + return success(); + }); + if (failed(loop)) + return failure(); + + SmallVector results; + results.reserve(loop->results.size()); + for (Value result : loop->results) + results.push_back(result); + return results; +} + +SmallVector getMaterializationRunSlotOutputKeys(const MaterializationRunSlot& slot, + size_t resultIndex) { + SmallVector keys; + keys.reserve(slot.peers.size()); + for (const ComputeInstance& peer : slot.peers) + keys.push_back({peer, resultIndex}); + return keys; +} + +FailureOr> +getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId logicalSlot) { + if (targetClass.isBatch) + return getPeerLogicalInstances(state, targetClass, logicalSlot); + + auto streamIt = state.logicalInstancesByCpu.find(targetClass.cpus.front()); + if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) + return failure(); + + return SmallVector {streamIt->second[logicalSlot]}; +} + +FailureOr collectBatchMaterializationRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + Operation* 
sourceOp) { + MaterializationRun run; + + for (SlotId slot = startSlot;; ++slot) { + ClassSlotKey classSlot {targetClass.id, slot}; + if (state.materializedLogicalSlots.contains(classSlot)) + break; + + FailureOr> peers = getMaterializationRunSlotPeers(state, targetClass, slot); + if (failed(peers) || peers->empty()) + break; + + bool validSlot = true; + for (const ComputeInstance& peer : *peers) { + if (peer.op != sourceOp || !isa(peer.op)) { + validSlot = false; + break; + } + } + + if (!validSlot) + break; + + MaterializationRunSlot runSlot; + runSlot.peers = std::move(*peers); + run.push_back(std::move(runSlot)); + } + + if (run.empty()) + return failure(); + + return run; +} + +SmallVector getMaterializationRunOutputKeys(ArrayRef run, size_t resultIndex) { + SmallVector keys; + for (const MaterializationRunSlot& slot : run) + llvm::append_range(keys, getMaterializationRunSlotOutputKeys(slot, resultIndex)); + return keys; +} + +ArrayRef getFirstMaterializationRunOriginalOutputs(MaterializerState& state, + ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + return getComputeInstanceOutputValuesCached(state, run.front().peers.front()); +} + +Operation* getMaterializationRunSourceOp(ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + return run.front().peers.front().op; +} + +Location getMaterializationRunLoc(ArrayRef run) { + return getMaterializationRunSourceOp(run)->getLoc(); +} + +bool hasMaterializationRunResultLiveExternalUse(MaterializerState& state, + ArrayRef run, + size_t resultIndex) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, peer); + if (resultIndex >= outputs.size()) + return true; + + if 
(hasLiveExternalUseCached(state, outputs[resultIndex])) + return true; + } + } + + return false; +} + +bool hasMaterializationRunGroupLiveExternalUse(MaterializerState& state, + ArrayRef run, + const OutputDestinationGroup& group) { + for (size_t resultIndex : group.resultIndices) + if (hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) + return true; + + return false; +} + +bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId); + +bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, + ClassId classId, + ArrayRef run, + const OutputDestinationGroup& group) { + for (size_t resultIndex : group.resultIndices) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) + if (hasSameClassConsumer(state, {peer, resultIndex}, classId)) + return true; + } + } + + return false; +} + +bool canRegisterDeferredLocalPackedRun(MaterializerState& state, ArrayRef run) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + for (Value input : getComputeInstanceInputs(peer)) { + std::optional producer = getInputRequestProducerKey(input, peer); + if (producer && isWholeBatchProducerKey(*producer)) + return false; + } + } + } + + return true; +} + +void markMaterializationRunSlots(MaterializerState& state, + ClassId classId, + SlotId startSlot, + ArrayRef run) { + for (auto slotIndex : llvm::seq(0, run.size())) + state.materializedLogicalSlots.insert({classId, startSlot + static_cast(slotIndex)}); +} + +LogicalResult materializeScalarBatchRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + ArrayRef run) { + assert(!targetClass.isBatch && "scalar batch run materialization expects scalar target class"); + assert(!run.empty() && "expected non-empty batch run"); + + markMaterializationRunSlots(state, targetClass.id, startSlot, run); + + SmallVector groups = 
groupBatchRunOutputsByDestination(state, run); + ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); + + auto sourceBatch = cast(getMaterializationRunSourceOp(run)); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + Location loc = getMaterializationRunLoc(run); + bool canDeferLocalPackedRun = canRegisterDeferredLocalPackedRun(state, run); + + for (const OutputDestinationGroup& group : groups) { + bool canUseLocalOnlyPackedRun = run.size() > 1 && group.destinationClasses.empty() + && !hasMaterializationRunGroupLiveExternalUse(state, run, group) + && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group); + if (canUseLocalOnlyPackedRun && canDeferLocalPackedRun) { + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); + + SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); + if (failed(registerDeferredLocalPackedRunValue(state, targetClass, keys, fragmentTypes[resultIndex], loc))) + return failure(); + } + + continue; + } + + FailureOr> packedOutputs = materializeBatchOutputGroupLoop(state, targetClass, run, group); + if (failed(packedOutputs)) + return failure(); + + for (auto [groupOutputIndex, resultIndex] : llvm::enumerate(group.resultIndices)) { + Value packed = (*packedOutputs)[groupOutputIndex]; + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); + + Type fragmentType = fragmentTypes[resultIndex]; + SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); + + if (run.size() == 1) { + if (failed(emitOutputFanout(state, targetClass, keys, packed, firstOriginalOutputs[resultIndex], loc))) + return failure(); + continue; + } + + if 
(canUseLocalOnlyPackedRun) { + if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) + return failure(); + continue; + } + + if (failed(emitPackedRunFanout(state, targetClass, group.destinationClasses, keys, packed, fragmentType, loc))) + return failure(); + + if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) + return failure(); + + Value representativeOutput = firstOriginalOutputs[resultIndex]; + if (hasLiveExternalUseCached(state, representativeOutput) + && isProjectedTerminalBatchHostOutput(representativeOutput, state.oldComputeOps)) { + std::optional groupedHostPublication = + tryEmitScalarPackedProjectedHostPublication(state, targetClass, keys, packed, representativeOutput, loc); + if (groupedHostPublication) { + if (failed(*groupedHostPublication)) + return failure(); + continue; + } + } + + auto rankedFragmentType = cast(fragmentType); + for (auto [runIndex, slot] : llvm::enumerate(run)) { + assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); + + ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, slot.peers.front()); + Value originalOutput = originalOutputs[resultIndex]; + + if (!hasLiveExternalUseCached(state, originalOutput)) + continue; + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + FailureOr fragment = + getPackedSliceForRunIndex(state, targetClass, packed, rankedFragmentType, runIndex, loc); + if (failed(fragment)) + return failure(); + + if (isProjectedTerminalBatchHostOutput(originalOutput, state.oldComputeOps)) { + ProducerKey key {slot.peers.front(), resultIndex}; + if (failed(emitProjectedBatchHostFragment(state, targetClass, key, *fragment, originalOutput, loc))) + return failure(); + continue; + } + + if (failed(emitHostCommunication(state, targetClass, *fragment, originalOutput))) + return failure(); + } + } + } + + return success(); +} + +bool 
hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { + SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId}; + auto it = state.sameClassConsumerIndex.find(lookupKey); + if (it == state.sameClassConsumerIndex.end()) + return false; + + for (ProducerKey existing : it->second) + if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing)) + return true; + return false; +} + +bool canCompactBatchClassRun(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run) { + if (run.size() < 2) + return false; + if (run.front().peers.empty()) + return false; + + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); + + for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { + (void) ignored; + for (const MaterializationRunSlot& slot : run) { + if (slot.peers.empty()) + return false; + + for (const ComputeInstance& peer : slot.peers) { + ArrayRef peerOutputs = getComputeInstanceOutputValuesCached(state, peer); + if (resultIndex >= peerOutputs.size()) + return false; + + Value originalOutput = peerOutputs[resultIndex]; + if (hasLiveExternalUseCached(state, originalOutput)) + return false; + + ProducerKey key {peer, resultIndex}; + if (hasSameClassConsumer(state, key, targetClass.id)) + return false; + } + } + } + + return true; +} + +LogicalResult registerMaterializedBatchRunHostOutputs(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const OutputDestinationGroup& group) { + ArrayRef originalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= originalOutputs.size()) + return targetClass.op->emitError("batch materialization host output index out of range"); + + Value originalOutput = originalOutputs[resultIndex]; + if (!hasLiveExternalUseCached(state, originalOutput)) + continue; + + auto resultIt 
= targetClass.hostOutputToResultIndex.find(originalOutput); + if (resultIt == targetClass.hostOutputToResultIndex.end()) + return targetClass.op->emitError("missing host result slot for materialized batch output"); + + state.hostReplacements[originalOutput] = targetClass.op->getResult(resultIt->second); + } + + return success(); +} + +LogicalResult verifyMaterializedHostOutputs(MaterializerState& state) { + for (SpatCompute compute : state.func.getOps()) { + auto yieldOp = dyn_cast_or_null(compute.getBody().front().getTerminator()); + if (!yieldOp) + return compute.emitOpError("expected spat.yield terminator in materialized compute"); + if (compute.getNumResults() != yieldOp.getNumOperands()) + return compute.emitOpError("materialized compute result count does not match spat.yield operand count"); + for (auto [result, yielded] : llvm::zip(compute.getResults(), yieldOp.getOperands())) + if (result.getType() != yielded.getType()) + return compute.emitOpError("ComputeOp output must be of the same type as yieldOp operand"); + } + + for (SpatChannelReceiveOp receive : state.func.getOps()) { + if (!receive.getOutput().use_empty()) + continue; + return receive.emitOpError("materialized channel_receive result must have at least one use"); + } + + for (const MaterializedClass& materializedClass : state.classes) { + if (!materializedClass.isBatch || materializedClass.hostOutputs.empty()) + continue; + + auto batch = dyn_cast(materializedClass.op); + auto inParallel = dyn_cast_or_null(materializedClass.body->getTerminator()); + if (!batch || !inParallel) + return materializedClass.op->emitError("expected resultful materialized compute_batch host owner"); + + for (Value hostOutput : materializedClass.hostOutputs) { + auto ownerIt = materializedClass.hostOutputToResultIndex.find(hostOutput); + if (ownerIt == materializedClass.hostOutputToResultIndex.end()) + return materializedClass.op->emitError("missing host result slot for materialized compute_batch host output"); + + auto 
outputArg = batch.getOutputArgument(ownerIt->second); + if (!outputArg) + return batch.emitOpError("missing output block argument for materialized compute_batch host output"); + + bool foundProjection = false; + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert || insert.getDest() != *outputArg) + continue; + foundProjection = true; + break; + } + + if (!foundProjection) + return batch.emitOpError( + "materialized terminal compute_batch host output is missing tensor.parallel_insert_slice publication"); + } + } + + for (const auto& [originalOutput, replacement] : state.hostReplacements) + if (originalOutput.getType() != replacement.getType()) + return replacement.getDefiningOp()->emitOpError("host output replacement type does not match original output type") + << " replacementType=" << replacement.getType() << " outputType=" << originalOutput.getType(); + + return success(); +} + +Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { + auto batch = cast(targetClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected materialized compute_batch lane argument"); + + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr d1 = getAffineDimExpr(1, context); + + int64_t laneCount = static_cast(targetClass.cpus.size()); + AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); + return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); +} + +Value createBatchClassRunSourceLane(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + Value slotIndex, + Location loc) { + SmallVector sourceLanes; + sourceLanes.reserve(run.size() * targetClass.cpus.size()); + + for (auto [runSlotIndex, slot] : llvm::enumerate(run)) { + (void) runSlotIndex; + assert(slot.peers.size() == 
targetClass.cpus.size() && "expected one peer per materialized batch lane"); + for (const ComputeInstance& peer : slot.peers) + sourceLanes.push_back(peer.laneStart); + } + + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + return createIndexedIndexValue(state, + targetClass.op, + sourceLanes, + flatIndex, + loc, + static_cast(targetClass.cpus.size()), + /*allowExhaustiveTiledSearch=*/false); +} + +LogicalResult buildBatchRunSendPlans(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const OutputDestinationGroup& group, + SmallVectorImpl& plans) { + assert(sourceClass.isBatch && "batch run send planning expects a materialized batch source"); + + for (size_t resultIndex : group.resultIndices) { + for (ClassId destinationClass : group.destinationClasses) { + if (destinationClass == sourceClass.id) + return sourceClass.op->emitError("batch-target run compaction cannot handle same-class consumers"); + + MaterializedClass& targetClass = state.classes[destinationClass]; + + if (targetClass.isBatch && targetClass.cpus.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError( + "cannot compact batch run communication between batch classes of different sizes"); + + BatchRunSendPlan plan; + plan.resultIndex = resultIndex; + plan.destinationClass = destinationClass; + + size_t messageCount = run.size() * sourceClass.cpus.size(); + plan.messages.channelIds.reserve(messageCount); + plan.messages.sourceCoreIds.reserve(messageCount); + plan.messages.targetCoreIds.reserve(messageCount); + + for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = + getCheckedCoreId(targetClass.op, + targetClass.isBatch ? 
targetClass.cpus[lane] : targetClass.cpus.front(), + "batch run target core id"); + if (failed(checkedTargetCpu)) + return failure(); + plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + } + (void) slotIndex; + } + + plans.push_back(std::move(plan)); + } + } + + return success(); +} + +void appendBatchRunSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const BatchRunSendPlan& plan, + Value flatIndex, + Location loc) { + assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); + + std::optional preferredPeriod = static_cast(sourceClass.cpus.size()); + Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); +} + +LogicalResult appendPackedScalarRunReceives(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType, + Location loc) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + assert(!targetClass.isBatch && "packed scalar run receives expect a scalar target class"); + + size_t laneCount = sourceClass.cpus.size(); + size_t receiveCount = run.size() * laneCount; + + if (failed(plan.messages.verify(targetClass.op))) + return failure(); + + if (receiveCount != plan.messages.size()) + return targetClass.op->emitError("inconsistent flattened batch run receive plan"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("packed scalar run receive 
expects static ranked fragment type"); + + PackedScalarRunValue packedRun; + packedRun.targetClass = targetClass.id; + packedRun.sourceOp = run.front().peers.front().op; + packedRun.resultIndex = plan.resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredReceive; + packedRun.fragmentType = rankedFragmentType; + + packedRun.messages = plan.messages; + + packedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot packedSlot; + packedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + packedRun.slots.push_back(std::move(packedSlot)); + } + + if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) + return failure(); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult recordIndexedBatchRunReceives(MaterializerState& state, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("indexed batch run receive expects static ranked fragment type"); + + IndexedBatchRunValue indexedRun; + indexedRun.targetClass = targetClass.id; + indexedRun.sourceOp = run.front().peers.front().op; + indexedRun.resultIndex = plan.resultIndex; + indexedRun.fragmentType = rankedFragmentType; + indexedRun.messages = plan.messages; + indexedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot indexedSlot; + indexedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + indexedRun.slots.push_back(std::move(indexedSlot)); + } + + state.availableValues.recordIndexedBatchRun(std::move(indexedRun)); + return success(); +} + +LogicalResult appendBatchRunReceives(MaterializerState& state, + MaterializedClass& 
sourceClass, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType, + Location loc) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + + if (!targetClass.isBatch) + return appendPackedScalarRunReceives(state, sourceClass, run, plan, fragmentType, loc); + return recordIndexedBatchRunReceives(state, run, plan, fragmentType); +} + +LogicalResult materializeBatchClassRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + ArrayRef run) { + assert(targetClass.isBatch && "batch-target run materialization expects a materialized batch class"); + assert(!run.empty() && "expected non-empty batch-target run"); + + if (!canCompactBatchClassRun(state, targetClass, run)) + return failure(); + + markMaterializationRunSlots(state, targetClass.id, startSlot, run); + + SmallVector groups = groupBatchRunOutputsByDestination(state, run); + + auto sourceBatch = cast(run.front().peers.front().op); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + Location loc = sourceBatch.getLoc(); + bool constantProjectionSlotIndex = requiresConstantProjectionSlotIndex(state, targetClass, sourceBatch); + + for (const OutputDestinationGroup& group : groups) { + SmallVector sendPlans; + if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) + return failure(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + if (constantProjectionSlotIndex) { + for (auto [slotIndex, slot] : llvm::enumerate(run)) { + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + Value slotIndexValue = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(slotIndex)); + Value 
sourceLane = getOrCreateIndexConstant(state.constantFolder, targetClass.op, slot.peers.front().laneStart); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndexValue, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + getScheduledChunkForLogicalInstance(state, run.front().peers.front()), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = slotIndexValue, + .projectionSlotIndex = slotIndexValue}); + if (failed(produced)) + return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); + if (resultIt == group.resultIndices.end()) + return failure(); + + size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); + appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); + } + } + } else { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + getScheduledChunkForLogicalInstance(state, run.front().peers.front()), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); + if (failed(produced)) + return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); + if (resultIt == group.resultIndices.end()) + return failure(); + + size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); + appendBatchRunSend(state, targetClass, 
(*produced)[groupOutputIndex], plan, flatIndex, loc); + } + return success(); + }); + if (failed(loop)) + return failure(); + } + + for (const BatchRunSendPlan& plan : sendPlans) { + if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) + return failure(); + + if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) + return failure(); + } + + if (failed(registerMaterializedBatchRunHostOutputs(state, targetClass, run, group))) + return failure(); + } + + return success(); +} + +LogicalResult materializeInstanceSlot(MaterializerState& state, + const ComputeInstance& instance) { + auto cpuIt = state.schedule.computeToCpuMap.find(instance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); + auto logicalRangeIt = state.scheduledInstanceToLogicalSlots.find(instance); + if (logicalRangeIt == state.scheduledInstanceToLogicalSlots.end()) + return instance.op->emitError("schedule materialization expected logical slots for every compute instance"); + + ClassId classId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& targetClass = state.classes[classId]; + + LogicalSlotRange logicalRange = logicalRangeIt->second; + SlotId startLogicalSlot = logicalRange.start; + while (startLogicalSlot < logicalRange.start + logicalRange.count + && state.materializedLogicalSlots.contains({classId, startLogicalSlot})) { + ++startLogicalSlot; + } + if (startLogicalSlot == logicalRange.start + logicalRange.count) + return success(); + + if (isa(instance.op)) { + FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); + + if (succeeded(run)) { + if (!targetClass.isBatch) + return materializeScalarBatchRun(state, targetClass, startLogicalSlot, *run); + + if (succeeded(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run))) + return 
success(); + } + } + + if (!state.materializedLogicalSlots.insert({classId, startLogicalSlot}).second) + return success(); + + FailureOr> peers = + getMaterializationRunSlotPeers(state, targetClass, startLogicalSlot); + if (failed(peers)) + return instance.op->emitError("failed to collect peer compute instances for equivalence class logical slot"); + + Value projectionSlotIndex = getOrCreateIndexConstant( + state.constantFolder, targetClass.op, static_cast(startLogicalSlot - logicalRange.start)); + FailureOr> materializedOutputs = + cloneInstanceBody(state, + targetClass, + *peers, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = projectionSlotIndex}); + if (failed(materializedOutputs)) + return failure(); + + ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, instance); + if (materializedOutputs->size() != originalOutputs.size()) + return instance.op->emitError("materialized output count does not match original compute instance output count"); + + for (auto [resultIndex, zipped] : llvm::enumerate(llvm::zip(*materializedOutputs, originalOutputs))) { + Value materializedOutput = std::get<0>(zipped); + Value originalOutput = std::get<1>(zipped); + MaterializationRunSlot slot; + slot.peers = *peers; + SmallVector keys = getMaterializationRunSlotOutputKeys(slot, resultIndex); + if (failed(emitOutputFanout(state, targetClass, keys, materializedOutput, originalOutput, instance.op->getLoc()))) + return failure(); + } + + return success(); +} + +FailureOr createReceiveConcatLoop(MaterializerState& state, + MaterializedClass& targetClass, + RankedTensorType concatType, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one receive"); + + Operation* insertionPoint = targetClass.body->getTerminator(); + state.rewriter.setInsertionPoint(insertionPoint); + 
Value init = + tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); + return emitIndexedFragmentInsertLoop( + state, + targetClass, + init, + static_cast(messages.size()), + [&](Value index) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc); + return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value index) -> FailureOr { + return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc); + }, + loc); +} + + +std::optional getDirectCommunicationOrderKey(Operation* op) { + if (!op) + return std::nullopt; + + Value channelId; + Value sourceCoreId; + Value targetCoreId; + if (auto send = dyn_cast(op)) { + channelId = send.getChannelId(); + sourceCoreId = send.getSourceCoreId(); + targetCoreId = send.getTargetCoreId(); + } + else if (auto receive = dyn_cast(op)) { + channelId = receive.getChannelId(); + sourceCoreId = receive.getSourceCoreId(); + targetCoreId = receive.getTargetCoreId(); + } + else { + return std::nullopt; + } + + auto channel = getConstantIndexValue(channelId); + auto source = getConstantIndexValue(sourceCoreId); + auto target = getConstantIndexValue(targetCoreId); + if (!channel || !source || !target) + return std::nullopt; + + return computeBlockingCommunicationOrderKey( + static_cast(*source), static_cast(*target), *channel); +} + +std::optional getScalarCommunicationOrderKey(Operation* op) { + if (!op) + return std::nullopt; + if (auto order = op->getAttrOfType(kRaptorCommOrderAttr)) + return order.getInt(); + if (auto directOrder = getDirectCommunicationOrderKey(op)) + return directOrder; + if (auto channel = 
op->getAttrOfType(kRaptorMinChannelIdAttr)) + return channel.getInt(); + return std::nullopt; +} + +bool isReorderableScalarCommunication(Operation* op) { + if (!getScalarCommunicationOrderKey(op).has_value()) + return false; + + // The global-order repair is intentionally conservative: it may reorder + // send-side projections, but it must not move receives or any other + // communication op that defines SSA values. Moving a receive after one of + // its users breaks MLIR dominance; moving it before the source can produce + // the payload can also create a receive/receive deadlock. Receives therefore + // have to be placed correctly by the materializer when they are created. + // Direct spat.channel_send operations are included even when they were not + // produced by appendScalarSendLoop and therefore do not carry raptor.* + // attributes yet. This is needed for large scalar-to-scalar payload transfers + // that must be hoisted before reciprocal receives. + return isa(op) || (op->getNumResults() == 0 && op->hasAttr(kRaptorMinChannelIdAttr)); +} + +Operation* getLaterOperationInBlock(Operation* lhs, Operation* rhs) { + if (!lhs) + return rhs; + if (!rhs) + return lhs; + return lhs->isBeforeInBlock(rhs) ? rhs : lhs; +} + +Operation* getNextInsertionPointAfter(Operation* op, Block& block) { + if (!op) + return &block.front(); + Operation* next = op->getNextNode(); + return next ? 
next : block.getTerminator(); +} + +bool hasConstantRoutingOperands(SpatChannelSendOp send) { + return getConstantIndexValue(send.getChannelId()).has_value() + && getConstantIndexValue(send.getSourceCoreId()).has_value() + && getConstantIndexValue(send.getTargetCoreId()).has_value(); +} + +Operation* getLatestSameBlockOperandDefinition(Operation* root, Block& block) { + Operation* latest = nullptr; + + auto consider = [&](Value value) { + Operation* definingOp = value.getDefiningOp(); + if (!definingOp || definingOp->getBlock() != &block || definingOp == root) + return; + latest = getLaterOperationInBlock(latest, definingOp); + }; + + // For direct sends with constant routing operands, only the payload is a real + // scheduling dependency. The channel/source/target constants can be + // rematerialized at the new insertion point. Treating those constants as hard + // dependencies prevents the repair from hoisting a ready send above an early + // receive, which is exactly the receive/receive deadlock pattern reported by + // the static communication checker. 
+ if (auto send = dyn_cast(root)) { + if (hasConstantRoutingOperands(send)) { + consider(send.getInput()); + return latest; + } + } + + for (Value operand : root->getOperands()) + consider(operand); + + for (Region& region : root->getRegions()) { + region.walk([&](Operation* nested) { + if (nested == root) + return; + for (Value operand : nested->getOperands()) + consider(operand); + }); + } + + return latest; +} + +void rematerializeDirectSendRoutingConstantsAt(MaterializerState& state, + SpatChannelSendOp send, + Operation* insertionPoint) { + if (!send || !insertionPoint || !hasConstantRoutingOperands(send)) + return; + + auto channel = getConstantIndexValue(send.getChannelId()); + auto source = getConstantIndexValue(send.getSourceCoreId()); + auto target = getConstantIndexValue(send.getTargetCoreId()); + if (!channel || !source || !target) + return; + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(insertionPoint); + Location loc = send.getLoc(); + Value newChannel = arith::ConstantIndexOp::create(state.rewriter, loc, *channel); + Value newSource = arith::ConstantIndexOp::create(state.rewriter, loc, *source); + Value newTarget = arith::ConstantIndexOp::create(state.rewriter, loc, *target); + send->setOperand(0, newChannel); + send->setOperand(1, newSource); + send->setOperand(2, newTarget); +} + +LogicalResult reorderScalarClassCommunicationByGlobalOrder(MaterializerState& state, + MaterializedClass& materializedClass) { + if (materializedClass.isBatch) + return success(); + + Block& block = *materializedClass.body; + Operation* terminator = block.getTerminator(); + SmallVector communicationOps; + for (Operation& op : block) { + if (&op == terminator) + break; + if (isReorderableScalarCommunication(&op)) + communicationOps.push_back(&op); + } + + if (communicationOps.size() < 2) + return success(); + + llvm::stable_sort(communicationOps, [](Operation* lhs, Operation* rhs) { + std::optional lhsOrder = 
getScalarCommunicationOrderKey(lhs); + std::optional rhsOrder = getScalarCommunicationOrderKey(rhs); + if (lhsOrder != rhsOrder) + return lhsOrder.value_or(std::numeric_limits::max()) + < rhsOrder.value_or(std::numeric_limits::max()); + return lhs->isBeforeInBlock(rhs); + }); + + Operation* lastPlacedCommunication = nullptr; + for (Operation* communication : communicationOps) { + if (communication->getBlock() != &block) + return materializedClass.op->emitError("scalar communication global-order repair saw a moved operation"); + + Operation* dependency = getLatestSameBlockOperandDefinition(communication, block); + Operation* anchor = getLaterOperationInBlock(lastPlacedCommunication, dependency); + Operation* insertionPoint = getNextInsertionPointAfter(anchor, block); + + if (insertionPoint != communication && communication->getNextNode() != insertionPoint) { + if (auto send = dyn_cast(communication)) + rematerializeDirectSendRoutingConstantsAt(state, send, insertionPoint); + communication->moveBefore(insertionPoint); + } + + lastPlacedCommunication = communication; + } + + return success(); +} + +LogicalResult reorderScalarCommunicationsByGlobalOrder(MaterializerState& state) { + for (MaterializedClass& materializedClass : state.classes) + if (failed(reorderScalarClassCommunicationByGlobalOrder(state, materializedClass))) + return failure(); + return success(); +} + + +Operation* getEarliestOperationInBlock(Operation* lhs, Operation* rhs) { + if (!lhs) + return rhs; + if (!rhs) + return lhs; + return lhs->isBeforeInBlock(rhs) ? 
lhs : rhs; +} + +Operation* getTopLevelOperationInBlock(Operation* op, Block& block) { + for (Operation* current = op; current; current = current->getParentOp()) { + if (current->getBlock() == &block) + return current; + } + return nullptr; +} + +Operation* findEarliestTopLevelUse(Operation* producer, Block& block) { + Operation* earliest = nullptr; + for (Value result : producer->getResults()) { + for (Operation* user : result.getUsers()) { + Operation* topLevelUser = getTopLevelOperationInBlock(user, block); + if (!topLevelUser || topLevelUser == producer) + continue; + earliest = getEarliestOperationInBlock(earliest, topLevelUser); + } + } + return earliest; +} + +LogicalResult sinkScalarReceivesToFirstUse(MaterializerState& state) { + for (MaterializedClass& materializedClass : state.classes) { + if (materializedClass.isBatch) + continue; + + Block& block = *materializedClass.body; + Operation* terminator = block.getTerminator(); + SmallVector receives; + for (Operation& op : block) { + if (&op == terminator) + break; + if (isa(&op)) + receives.push_back(&op); + } + + for (Operation* receive : receives) { + if (receive->getBlock() != &block) + continue; + + Operation* firstUse = findEarliestTopLevelUse(receive, block); + if (!firstUse || firstUse == receive || firstUse->getBlock() != &block) + continue; + + if (!receive->isBeforeInBlock(firstUse)) + continue; + + if (receive->getNextNode() == firstUse) + continue; + + receive->setAttr("raptor.receive_sunk_to_first_use", UnitAttr::get(receive->getContext())); + receive->moveBefore(firstUse); + } + } + return success(); +} + +void replaceHostUses(MaterializerState& state) { + for (const auto& [oldValue, replacement] : state.hostReplacements) + replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); +} + +LogicalResult eraseOldComputeOps(MaterializerState& state) { + DenseSet seen; + for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { + if (!seen.insert(instance.op).second) 
+ continue; + instance.op->dropAllUses(); + instance.op->erase(); + } + return success(); +} + +} // namespace + +LogicalResult +MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { + if (schedule.dominanceOrderCompute.empty()) + return success(); + + MaterializerState state(func, schedule, nextChannelId); + if (failed(buildMaterializationWorkStreams(state))) + return failure(); + if (failed(buildMaterializationClassesFromScheduleEquivalence(state))) + return failure(); + if (failed(verifyScheduleEquivalenceMatchesLogicalStreams(state))) + return failure(); + if (state.classes.empty()) + return success(); + + if (failed(collectHostOutputs(state))) + return failure(); + if (failed(createEmptyMaterializedOps(state))) + return failure(); + if (failed(collectProducerDestinations(state))) + return failure(); + if (failed(collectProjectedTransfers(state))) + return failure(); + + for (const ComputeInstance& instance : schedule.dominanceOrderCompute) + if (failed(materializeInstanceSlot(state, instance))) + return failure(); + + for (MaterializedClass& materializedClass : state.classes) + if (failed(localizeAllScheduledBodyCaptures(state, materializedClass))) + return failure(); + + if (failed(flushPendingProjectedHostReceives(state))) + return failure(); + + if (pimMaterializeScalarFanoutGlobalOrder) { + if (failed(sinkScalarReceivesToFirstUse(state))) + return failure(); + if (failed(reorderScalarCommunicationsByGlobalOrder(state))) + return failure(); + } + + if (failed(verifyMaterializedHostOutputs(state))) + return failure(); + + replaceHostUses(state); + if (failed(eraseOldComputeOps(state))) + return failure(); + + LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); + (void) _; + + return success(); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig 
b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig new file mode 100644 index 0000000..1244543 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.orig @@ -0,0 +1,7548 @@ +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include +#include +#include +#include + +#include "MaterializeMergeSchedule.hpp" +#include "Scheduling/ComputeInstanceUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp" +#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" + +using namespace mlir; + +namespace onnx_mlir { +namespace spatial { +namespace { + +using CpuId = size_t; +using ClassId = size_t; +using SlotId = size_t; + +static FailureOr getCheckedCoreId(Operation* anchor, CpuId cpu, StringRef fieldName) { + return pim::checkedI32(static_cast(cpu), anchor, fieldName); +} + +static FailureOr> +getCheckedCoreIds(Operation* anchor, ArrayRef cpus, StringRef fieldName) { + SmallVector coreIds; + coreIds.reserve(cpus.size()); + for (CpuId cpu : cpus) { + auto checkedCoreId = getCheckedCoreId(anchor, cpu, fieldName); + if (failed(checkedCoreId)) + return failure(); + 
coreIds.push_back(*checkedCoreId); + } + return coreIds; +} + +struct MessageVector { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + + size_t size() const { return channelIds.size(); } + bool empty() const { return channelIds.empty(); } + + LogicalResult verify(Operation* anchor) const { + if (channelIds.size() != sourceCoreIds.size() || channelIds.size() != targetCoreIds.size()) + return anchor->emitError("message metadata is inconsistent"); + return success(); + } + + void append(int64_t channelId, int32_t sourceCoreId, int32_t targetCoreId) { + channelIds.push_back(channelId); + sourceCoreIds.push_back(sourceCoreId); + targetCoreIds.push_back(targetCoreId); + } + + void append(ArrayRef channels, ArrayRef sources, ArrayRef targets) { + assert(channels.size() == sources.size() && "channel/source count mismatch"); + assert(channels.size() == targets.size() && "channel/target count mismatch"); + llvm::append_range(channelIds, channels); + llvm::append_range(sourceCoreIds, sources); + llvm::append_range(targetCoreIds, targets); + } + + MessageVector slice(size_t offset, size_t count) const { + MessageVector result; + result.append(ArrayRef(channelIds).slice(offset, count), + ArrayRef(sourceCoreIds).slice(offset, count), + ArrayRef(targetCoreIds).slice(offset, count)); + return result; + } +}; + +struct ProducerKey { + ComputeInstance instance; + size_t resultIndex = 0; + + bool operator==(const ProducerKey& other) const { + return instance == other.instance && resultIndex == other.resultIndex; + } +}; + +struct ProducerKeyInfo { + static ProducerKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProducerKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProducerKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.instance), key.resultIndex); + } + + static 
bool isEqual(const ProducerKey& lhs, const ProducerKey& rhs) { return lhs == rhs; } +}; + +struct SameClassConsumerLookupKey { + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const SameClassConsumerLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct SameClassConsumerLookupKeyInfo { + static SameClassConsumerLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static SameClassConsumerLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const SameClassConsumerLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); + } + + static bool isEqual(const SameClassConsumerLookupKey& lhs, const SameClassConsumerLookupKey& rhs) { + return lhs == rhs; + } +}; + +struct WholeBatchAssemblyLookupKey { + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + ClassId classId = 0; + + bool operator==(const WholeBatchAssemblyLookupKey& other) const { + return sourceOp == other.sourceOp && resultIndex == other.resultIndex && classId == other.classId; + } +}; + +struct WholeBatchAssemblyLookupKeyInfo { + static WholeBatchAssemblyLookupKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static WholeBatchAssemblyLookupKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), std::numeric_limits::max(), + std::numeric_limits::max()}; + } + + static unsigned getHashValue(const WholeBatchAssemblyLookupKey& key) { + return llvm::hash_combine(llvm::DenseMapInfo::getHashValue(key.sourceOp), key.resultIndex, key.classId); + } + + static bool isEqual(const WholeBatchAssemblyLookupKey& 
lhs, const WholeBatchAssemblyLookupKey& rhs) { + return lhs == rhs; + } +}; + +using ClassSlotKey = std::pair; + +struct MaterializedClass { + ClassId id = 0; + SmallVector cpus; + Operation* op = nullptr; + Block* body = nullptr; + bool isBatch = false; + + DenseMap cpuToLane; + SmallVector weights; + SmallVector inputs; + SmallVector hostOutputs; + DenseMap weightArgs; + DenseMap inputArgs; + DenseMap hostOutputToResultIndex; +}; + +struct PackedScalarRunSlot { + SmallVector keys; +}; + +enum class PackedScalarRunKind { + Materialized, + DeferredReceive, + DeferredLocalCompute +}; + +struct PackedScalarRunValue { + ClassId targetClass = 0; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + PackedScalarRunKind kind = PackedScalarRunKind::Materialized; + + Value packed; + + RankedTensorType fragmentType; + SmallVector slots; + MessageVector messages; +}; + +struct IndexedBatchRunValue { + ClassId targetClass = 0; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + RankedTensorType fragmentType; + SmallVector slots; + MessageVector messages; +}; + +struct LogicalSlotRange { + SlotId start = 0; + SlotId count = 0; +}; + +struct MaterializationRunSlot { + SmallVector peers; +}; + +using MaterializationRun = SmallVector; + +struct OutputDestinationGroup { + SmallVector resultIndices; + SmallVector destinationClasses; +}; + +struct BatchRunSendPlan { + size_t resultIndex = 0; + ClassId destinationClass = 0; + MessageVector messages; +}; + +struct ProjectedBatchInputKey { + Operation* consumerOp = nullptr; + unsigned inputIndex = 0; + + bool operator==(const ProjectedBatchInputKey& other) const { + return consumerOp == other.consumerOp && inputIndex == other.inputIndex; + } +}; + +struct ProjectedBatchInputKeyInfo { + static ProjectedBatchInputKey getEmptyKey() { + return {llvm::DenseMapInfo::getEmptyKey(), std::numeric_limits::max()}; + } + + static ProjectedBatchInputKey getTombstoneKey() { + return {llvm::DenseMapInfo::getTombstoneKey(), 
std::numeric_limits::max()}; + } + + static unsigned getHashValue(const ProjectedBatchInputKey& key) { + return llvm::hash_combine(key.consumerOp, key.inputIndex); + } + + static bool isEqual(const ProjectedBatchInputKey& lhs, const ProjectedBatchInputKey& rhs) { return lhs == rhs; } +}; + +struct ProjectedFragmentLayout { + RankedTensorType fragmentType; + SmallVector fragmentShape; + unsigned fragmentsPerLogicalSlot = 1; + unsigned payloadFragmentCount = 1; + SmallVector loopLowerBounds; + SmallVector loopSteps; + SmallVector loopTripCounts; +}; + +struct ProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + Operation* extractOp = nullptr; + + ProjectedFragmentLayout layout; + RankedTensorType payloadType; + SmallVector, 16> fragmentOffsets; + SmallVector, 4> fragmentOffsetsByDim; +}; + +struct ProjectedExtractReplacement { + Value payload; + ProjectedFragmentLayout layout; +}; + +struct PendingProjectedHostOutputFragment { + Value originalOutput; + ClassId sourceClass = 0; + Value fragment; + RankedTensorType fragmentType; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + uint32_t sourceLane = 0; + Location loc; + + // When a materialized batch class is the source, the send must be emitted + // once with lane-indexed channel metadata. Finalization then only emits the + // matching scalar receive for each recorded fragment. Scalar sources keep + // the old behavior and emit their send during finalization. 
+ bool sendAlreadyEmitted = false; + MessageVector messages; +}; + +struct CloneIndexingContext { + std::optional runSlotIndex; + std::optional projectionSlotIndex; +}; + +struct StaticProjectedLoopInfo { + BlockArgument iv; + int64_t lowerBound = 0; + int64_t step = 1; + int64_t tripCount = 1; +}; + +struct AffineProjectedInputSliceMatch { + tensor::ExtractSliceOp extract; + RankedTensorType sourceType; + RankedTensorType fragmentType; + SmallVector fragmentShape; + SmallVector offsets; + SmallVector loops; +}; + +struct MaterializerState; + +FailureOr materializeProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + const ProjectedExtractReplacement& replacement, + std::optional projectionSlotIndex, + IRMapping* mapper = nullptr); +FailureOr rematerializeTensorValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + IRMapping* mapper = nullptr); +FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + std::optional producer = std::nullopt, + IRMapping* mapper = nullptr); +FailureOr localizeMaterializedClassOperand(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper = nullptr); +LogicalResult localizeCapturesInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + Operation& clonedOp, + IRMapping* mapper = nullptr); +LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass); +bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch, + const AffineProjectedInputSliceMatch& match, + ProducerKey producer, + uint32_t consumerLane); +std::optional getProjectedInputSliceMatch(MaterializerState& state, 
+ SpatComputeBatch batch, + unsigned inputIndex); + +class AvailableValueStore { +public: + struct ExactBatchFragmentRecord { + ProducerKey key; + Value value; + }; + + void record(ProducerKey key, ClassId classId, Value value) { + exactValues[key][classId] = value; + + auto batch = dyn_cast_or_null(key.instance.op); + if (!batch || key.instance.laneCount == 0) + return; + + WholeBatchAssemblyLookupKey lookupKey {batch.getOperation(), key.resultIndex, classId}; + SmallVector& bucket = exactBatchFragmentsByProducerResultClass[lookupKey]; + for (ExactBatchFragmentRecord& record : bucket) { + if (!(record.key == key)) + continue; + record.value = value; + return; + } + bucket.push_back({key, value}); + } + + void recordPackedRun(PackedScalarRunValue run) { + size_t runIndex = packedScalarRuns.size(); + packedScalarRuns.push_back(std::move(run)); + const PackedScalarRunValue& storedRun = packedScalarRuns[runIndex]; + WholeBatchAssemblyLookupKey lookupKey {storedRun.sourceOp, storedRun.resultIndex, storedRun.targetClass}; + packedRunsByProducerResultClass[lookupKey].push_back(runIndex); + } + void recordIndexedBatchRun(IndexedBatchRunValue run) { indexedBatchRuns.push_back(std::move(run)); } + + std::optional lookupExact(ProducerKey key, ClassId classId) const; + + std::optional lookup(MaterializerState& state, ProducerKey key, ClassId classId); + IndexedBatchRunValue* lookupIndexedBatchRun(ProducerKey key, ClassId classId); + + ArrayRef getPackedRunIndicesForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = packedRunsByProducerResultClass.find(key); + if (it == packedRunsByProducerResultClass.end()) + return {}; + return it->second; + } + + ArrayRef getExactFragmentsForWholeBatch(WholeBatchAssemblyLookupKey key) const { + auto it = exactBatchFragmentsByProducerResultClass.find(key); + if (it == exactBatchFragmentsByProducerResultClass.end()) + return {}; + return it->second; + } + + PackedScalarRunValue& getPackedRun(size_t index) { return 
packedScalarRuns[index]; } + +private: + std::optional lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId); + + DenseMap, ProducerKeyInfo> exactValues; + SmallVector packedScalarRuns; + SmallVector indexedBatchRuns; + DenseMap, WholeBatchAssemblyLookupKeyInfo> + exactBatchFragmentsByProducerResultClass; + DenseMap, WholeBatchAssemblyLookupKeyInfo> + packedRunsByProducerResultClass; +}; + +struct MaterializerState { + func::FuncOp func; + const MergeScheduleResult& schedule; + IRRewriter rewriter; + OperationFolder constantFolder; + int64_t& nextChannelId; + SmallVector classes; + DenseMap cpuToClass; + DenseMap> logicalInstancesByCpu; + DenseMap scheduledInstanceToLogicalSlots; + DenseMap logicalInstanceToScheduledChunk; + DenseSet materializedLogicalSlots; + + DenseMap, ProducerKeyInfo> producerDestClasses; + DenseMap, SameClassConsumerLookupKeyInfo> + sameClassConsumerIndex; + DenseMap projectedInputMatches; + DenseSet nonProjectedInputs; + DenseMap liveExternalUseCache; + DenseMap> batchOutputFragmentTypesCache; + DenseMap, llvm::DenseMapInfo> computeInstanceOutputsCache; + DenseMap, ProducerKeyInfo> projectedTransfers; + DenseMap> projectedExtractReplacements; + AvailableValueStore availableValues; + DenseMap hostReplacements; + DenseMap hostOutputOwners; + SmallVector pendingProjectedHostOutputFragments; + DenseSet oldComputeOps; + + MaterializerState(func::FuncOp func, + const MergeScheduleResult& schedule, + int64_t& nextChannelId) + : func(func), + schedule(schedule), + rewriter(func.getContext()), + constantFolder(func.getContext()), + nextChannelId(nextChannelId) {} +}; + +bool isConstantLike(Value value) { + Operation* definingOp = value.getDefiningOp(); + return definingOp && definingOp->hasTrait(); +} + +bool isInsideOldCompute(Operation* op, const DenseSet& oldComputeOps) { + for (Operation* current = op; current; current = current->getParentOp()) + if (oldComputeOps.contains(current)) + return true; + return false; +} + +bool 
hasLiveExternalUse(Value value, const DenseSet& oldComputeOps); +ArrayRef getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance); + +bool hasLiveExternalUseCached(MaterializerState& state, Value value) { + auto cached = state.liveExternalUseCache.find(value); + if (cached != state.liveExternalUseCache.end()) + return cached->second; + bool live = hasLiveExternalUse(value, state.oldComputeOps); + state.liveExternalUseCache[value] = live; + return live; +} + +std::optional getConstantFirstSliceOffset(tensor::ExtractSliceOp extract) { + if (extract.getMixedOffsets().empty()) + return std::nullopt; + + OpFoldResult offset = extract.getMixedOffsets().front(); + if (auto attr = dyn_cast(offset)) { + auto intAttr = dyn_cast(attr); + if (!intAttr || intAttr.getInt() < 0) + return std::nullopt; + return static_cast(intAttr.getInt()); + } + + auto value = cast(offset); + if (auto constantIndex = value.getDefiningOp()) { + if (constantIndex.value() < 0) + return std::nullopt; + return static_cast(constantIndex.value()); + } + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) { + if (constantValue.isNegative()) + return std::nullopt; + return static_cast(constantValue.getZExtValue()); + } + + return std::nullopt; +} + +ProducerKey +getBatchLaneProducerKey(SpatComputeBatch batch, uint32_t laneStart, uint32_t laneCount, size_t resultIndex) { + return { + {batch.getOperation(), laneStart, laneCount}, + resultIndex + }; +} + +ProducerKey getWholeBatchProducerKey(SpatComputeBatch batch, size_t resultIndex) { + return getBatchLaneProducerKey(batch, 0, static_cast(batch.getLaneCount()), resultIndex); +} + +bool isWholeBatchProducerKey(ProducerKey key) { + auto batch = dyn_cast_or_null(key.instance.op); + return batch && batch.getNumResults() != 0 && key.instance.laneStart == 0 + && key.instance.laneCount == static_cast(batch.getLaneCount()); +} + +std::optional getContiguousProducerRangeForKeys(ArrayRef keys) { + if 
(keys.empty()) + return std::nullopt; + + ProducerKey first = keys.front(); + auto batch = dyn_cast_or_null(first.instance.op); + if (!batch) + return std::nullopt; + + SmallVector sorted(keys.begin(), keys.end()); + llvm::sort(sorted, [](ProducerKey lhs, ProducerKey rhs) { + return std::tie(lhs.instance.laneStart, lhs.instance.laneCount, lhs.resultIndex) + < std::tie(rhs.instance.laneStart, rhs.instance.laneCount, rhs.resultIndex); + }); + + uint32_t laneStart = sorted.front().instance.laneStart; + uint32_t nextLane = laneStart; + for (ProducerKey key : sorted) { + if (key.instance.op != first.instance.op || key.resultIndex != first.resultIndex || key.instance.laneCount == 0) + return std::nullopt; + if (key.instance.laneStart != nextLane) + return std::nullopt; + nextLane += key.instance.laneCount; + } + + uint32_t laneCount = nextLane - laneStart; + if (laneStart + laneCount > static_cast(batch.getLaneCount())) + return std::nullopt; + + return getBatchLaneProducerKey(batch, laneStart, laneCount, first.resultIndex); +} + +WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(Operation* sourceOp, size_t resultIndex, ClassId classId) { + return {sourceOp, resultIndex, classId}; +} + +WholeBatchAssemblyLookupKey makeWholeBatchAssemblyLookupKey(ProducerKey key, ClassId classId) { + return makeWholeBatchAssemblyLookupKey(key.instance.op, key.resultIndex, classId); +} + +FailureOr getPackedBatchTensorType(Type laneType, size_t laneCount) { + auto tensorType = dyn_cast(laneType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return failure(); + + SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(laneCount); + return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +LogicalResult verifyPackableFragmentType(Operation* anchor, Type fragmentType, size_t count, StringRef message) { + if (failed(getPackedBatchTensorType(fragmentType, count))) + return 
anchor->emitError(message); + return success(); +} + +ComputeInstance getScheduledChunkForLogicalInstance(MaterializerState& state, ComputeInstance logicalInstance) { + auto it = state.logicalInstanceToScheduledChunk.find(logicalInstance); + if (it != state.logicalInstanceToScheduledChunk.end()) + return it->second; + return logicalInstance; +} + +SmallVector +collectProducerKeysForDestinations(Value value, std::optional logicalConsumer = std::nullopt) { + // Destination collection works in the materializer's logical one-lane key domain. + // Whole-batch resultful producers are expanded into per-lane producer keys here. + SmallVector keys; + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return keys; + + while (auto extract = dyn_cast(definingOp)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + auto result = dyn_cast(source); + if (!result) + return {}; + + if (std::optional lane = getConstantFirstSliceOffset(extract)) { + if (*lane >= static_cast(batch.getLaneCount())) + return {}; + keys.push_back(getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber())); + return keys; + } + + if (logicalConsumer && isa(logicalConsumer->op)) { + keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber())); + return keys; + } + + return {}; + } + + value = source; + definingOp = value.getDefiningOp(); + if (!definingOp) + return {}; + } + + if (auto compute = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return {}; + keys.push_back({ + {compute.getOperation(), 0, 1}, + result.getResultNumber() + }); + return keys; + } + + if (auto batch = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return {}; + + if (batch.getNumResults() != 0) { + if (logicalConsumer && isa(logicalConsumer->op)) { + keys.push_back(getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 
1, result.getResultNumber())); + return keys; + } + + for (uint32_t lane = 0; lane < static_cast(batch.getLaneCount()); ++lane) + keys.push_back(getBatchLaneProducerKey(batch, lane, 1, result.getResultNumber())); + return keys; + } + + ComputeInstance chunk = getBatchChunkForLane(batch, result.getResultNumber()); + keys.push_back({chunk, static_cast(result.getResultNumber() - chunk.laneStart)}); + return keys; + } + + return keys; +} + +std::optional getInputRequestProducerKey(Value value, + std::optional logicalConsumer = std::nullopt) { + // Input resolution may request a whole-batch key for scalar consumers that read + // a complete resultful compute_batch value. + Operation* definingOp = value.getDefiningOp(); + if (!definingOp) + return std::nullopt; + + while (auto extract = dyn_cast(definingOp)) { + Value source = extract.getSource(); + auto batch = dyn_cast_or_null(source.getDefiningOp()); + if (batch && batch.getNumResults() != 0) { + auto result = dyn_cast(source); + if (!result) + return std::nullopt; + + if (std::optional lane = getConstantFirstSliceOffset(extract)) + return getBatchLaneProducerKey(batch, *lane, 1, result.getResultNumber()); + + if (logicalConsumer && isa(logicalConsumer->op)) + return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); + + return std::nullopt; + } + + value = source; + definingOp = value.getDefiningOp(); + if (!definingOp) + return std::nullopt; + } + + if (auto compute = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + return ProducerKey { + {compute.getOperation(), 0, 1}, + result.getResultNumber() + }; + } + + if (auto batch = dyn_cast(definingOp)) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + + if (batch.getNumResults() != 0) { + if (logicalConsumer && isa(logicalConsumer->op)) + return getBatchLaneProducerKey(batch, logicalConsumer->laneStart, 1, result.getResultNumber()); + return 
getWholeBatchProducerKey(batch, result.getResultNumber()); + } + + return ProducerKey {getBatchChunkForLane(batch, result.getResultNumber()), 0}; + } + + return std::nullopt; +} + +std::optional getWholeBatchProducerKeyForDirectBatchResult(Value value) { + auto result = dyn_cast(value); + if (!result) + return std::nullopt; + + auto batch = dyn_cast_or_null(result.getOwner()); + if (!batch || batch.getNumResults() == 0) + return std::nullopt; + + return getWholeBatchProducerKey(batch, result.getResultNumber()); +} + +bool canUseProjectedLaneInput(MaterializerState& state, + SpatComputeBatch consumerBatch, + unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer) { + auto producerResult = dyn_cast(input); + if (!producerResult) + return false; + + auto producerBatch = dyn_cast_or_null(producerResult.getOwner()); + if (!producerBatch || producerBatch.getNumResults() == 0) + return false; + + std::optional match = + getProjectedInputSliceMatch(state, consumerBatch, inputIndex); + if (!match) + return false; + + ProducerKey laneProducer = + getBatchLaneProducerKey(producerBatch, logicalConsumer.laneStart, 1, producerResult.getResultNumber()); + return isProjectedInputSliceCompatibleWithProducerFragments( + consumerBatch, *match, laneProducer, logicalConsumer.laneStart); +} + +SmallVector collectProducerKeysForBatchInputDestinations(MaterializerState& state, + SpatComputeBatch consumerBatch, + unsigned inputIndex, + Value input, + ComputeInstance logicalConsumer) { + if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input)) { + if (!canUseProjectedLaneInput(state, consumerBatch, inputIndex, input, logicalConsumer)) { + auto producerBatch = cast(wholeBatchProducer->instance.op); + SmallVector keys; + for (uint32_t lane = 0; lane < static_cast(producerBatch.getLaneCount()); ++lane) + keys.push_back(getBatchLaneProducerKey(producerBatch, lane, 1, wholeBatchProducer->resultIndex)); + return keys; + } + } + + return 
collectProducerKeysForDestinations(input, logicalConsumer); +} + +class CpuUnionFind { +public: + void insert(CpuId cpu) { parent.try_emplace(cpu, cpu); } + + CpuId find(CpuId cpu) { + insert(cpu); + CpuId p = parent.lookup(cpu); + if (p == cpu) + return cpu; + CpuId root = find(p); + parent[cpu] = root; + return root; + } + + void unite(CpuId lhs, CpuId rhs) { + CpuId lhsRoot = find(lhs); + CpuId rhsRoot = find(rhs); + if (lhsRoot == rhsRoot) + return; + if (rhsRoot < lhsRoot) + std::swap(lhsRoot, rhsRoot); + parent[rhsRoot] = lhsRoot; + } + +private: + DenseMap parent; +}; + +LogicalResult buildMaterializationWorkStreams(MaterializerState& state) { + DenseMap> scheduledInstancesByCpu; + for (const auto& [instance, cpu] : state.schedule.computeToCpuMap) { + state.oldComputeOps.insert(instance.op); + scheduledInstancesByCpu[cpu].push_back(instance); + state.logicalInstancesByCpu.try_emplace(cpu); + } + + for (auto& [cpu, scheduledInstances] : scheduledInstancesByCpu) { + llvm::sort(scheduledInstances, [&](const ComputeInstance& lhs, const ComputeInstance& rhs) { + auto lhsIt = state.schedule.computeToCpuSlotMap.find(lhs); + auto rhsIt = state.schedule.computeToCpuSlotMap.find(rhs); + assert(lhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); + assert(rhsIt != state.schedule.computeToCpuSlotMap.end() && "missing scheduler slot"); + return lhsIt->second < rhsIt->second; + }); + + SmallVector& logicalInstances = state.logicalInstancesByCpu[cpu]; + SlotId logicalSlot = 0; + for (const ComputeInstance& instance : scheduledInstances) { + LogicalSlotRange range {logicalSlot, 1}; + if (isa(instance.op)) + range.count = instance.laneCount; + + state.scheduledInstanceToLogicalSlots[instance] = range; + + if (isa(instance.op)) { + for (uint32_t localLane = 0; localLane < instance.laneCount; ++localLane, ++logicalSlot) { + uint32_t logicalLane = instance.laneStart + localLane; + ComputeInstance logicalInstance {instance.op, logicalLane, 1}; + 
logicalInstances.push_back(logicalInstance); + state.logicalInstanceToScheduledChunk[logicalInstance] = instance; + } + continue; + } + + logicalInstances.push_back(instance); + ++logicalSlot; + } + } + + return success(); +} + +LogicalResult buildMaterializationClassesFromScheduleEquivalence(MaterializerState& state) { + DenseSet usedCpus; + for (const auto& entry : state.schedule.cpuToLastComputeMap) + usedCpus.insert(entry.first); + for (const auto& entry : state.schedule.computeToCpuMap) + usedCpus.insert(entry.second); + + CpuUnionFind unionFind; + for (CpuId cpu : usedCpus) + unionFind.insert(cpu); + + for (const auto& [cpu, equivalentCpus] : state.schedule.equivalentClass) { + if (!usedCpus.contains(cpu)) + continue; + for (CpuId equivalentCpu : equivalentCpus) + if (usedCpus.contains(equivalentCpu)) + unionFind.unite(cpu, equivalentCpu); + } + + DenseMap> groupsByRoot; + for (CpuId cpu : usedCpus) + groupsByRoot[unionFind.find(cpu)].push_back(cpu); + + SmallVector roots; + roots.reserve(groupsByRoot.size()); + for (const auto& entry : groupsByRoot) + roots.push_back(entry.first); + llvm::sort(roots); + + state.classes.reserve(roots.size()); + for (CpuId root : roots) { + MaterializedClass materializedClass; + materializedClass.id = state.classes.size(); + materializedClass.cpus = groupsByRoot.lookup(root); + llvm::sort(materializedClass.cpus); + materializedClass.isBatch = materializedClass.cpus.size() > 1; + for (auto [lane, cpu] : llvm::enumerate(materializedClass.cpus)) { + materializedClass.cpuToLane[cpu] = static_cast(lane); + state.cpuToClass[cpu] = materializedClass.id; + } + state.classes.push_back(std::move(materializedClass)); + } + + return success(); +} + +LogicalResult verifyScheduleEquivalenceMatchesLogicalStreams(MaterializerState& state) { + for (const MaterializedClass& materializedClass : state.classes) { + if (materializedClass.cpus.empty()) + continue; + + auto referenceIt = 
state.logicalInstancesByCpu.find(materializedClass.cpus.front()); + if (referenceIt == state.logicalInstancesByCpu.end()) + return state.func.emitError("missing logical stream for materialized class reference CPU"); + + ArrayRef referenceStream(referenceIt->second); + for (CpuId cpu : materializedClass.cpus) { + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end()) + return state.func.emitError("missing logical stream for materialized class CPU"); + + ArrayRef stream(streamIt->second); + if (stream.size() != referenceStream.size()) + return state.func.emitError("materialized class CPUs have mismatched logical stream lengths"); + + for (auto [slot, zipped] : llvm::enumerate(llvm::zip(referenceStream, stream))) { + const ComputeInstance& referenceInstance = std::get<0>(zipped); + const ComputeInstance& currentInstance = std::get<1>(zipped); + if (referenceInstance.op != currentInstance.op) + return state.func.emitError("materialized class logical slot source op mismatch"); + if (isa(referenceInstance.op) != isa(currentInstance.op)) + return state.func.emitError("materialized class logical slot batch/scalar mismatch"); + (void) slot; + } + } + } + + return success(); +} + +LogicalResult forEachLogicalConsumerInMaterializationOrder( + MaterializerState& state, + llvm::function_ref + callback) { + for (const ComputeInstance& scheduledInstance : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(scheduledInstance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return scheduledInstance.op->emitError("missing CPU assignment for scheduled logical-slot iteration"); + + auto rangeIt = state.scheduledInstanceToLogicalSlots.find(scheduledInstance); + if (rangeIt == state.scheduledInstanceToLogicalSlots.end()) + return scheduledInstance.op->emitError("missing logical slot range for scheduled logical-slot iteration"); + + CpuId cpu = cpuIt->second; + ClassId classId = 
state.cpuToClass.lookup(cpu); + LogicalSlotRange range = rangeIt->second; + auto streamIt = state.logicalInstancesByCpu.find(cpu); + if (streamIt == state.logicalInstancesByCpu.end()) + return scheduledInstance.op->emitError("missing logical stream for CPU"); + for (SlotId logicalSlot = range.start; logicalSlot < range.start + range.count; ++logicalSlot) { + if (logicalSlot >= streamIt->second.size()) + return scheduledInstance.op->emitError("missing logical slot materialization instance"); + if (failed(callback(cpu, classId, scheduledInstance, streamIt->second[logicalSlot], logicalSlot))) + return failure(); + } + } + + return success(); +} + +bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps); + +LogicalResult collectHostOutputs(MaterializerState& state) { + DenseSet seenOutputs; + SmallVector orderedOutputs; + DenseMap preferredOwners; + + for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { + auto cpuIt = state.schedule.computeToCpuMap.find(instance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); + + ClassId classId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& materializedClass = state.classes[classId]; + for (Value output : getComputeInstanceOutputValuesCached(state, instance)) { + if (!hasLiveExternalUseCached(state, output)) + continue; + + if (seenOutputs.insert(output).second) { + orderedOutputs.push_back(output); + preferredOwners[output] = classId; + continue; + } + + auto batch = dyn_cast_or_null(output.getDefiningOp()); + if (!batch || batch.getNumResults() == 0) + continue; + + ClassId currentOwner = preferredOwners.lookup(output); + bool terminalHost = isTerminalHostBatchOutput(output, state.oldComputeOps); + if (terminalHost) { + // Terminal resultful batch outputs are still published through scalar + // host-output slots unless the materialized batch class owns 
the output + // directly. Selecting an arbitrary batch class as the host owner would + // require a projection-aware batch publication path, which the + // materializer does not currently implement. + if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + preferredOwners[output] = classId; + continue; + } + + if (state.classes[currentOwner].isBatch && !materializedClass.isBatch) + preferredOwners[output] = classId; + } + } + + for (MaterializedClass& materializedClass : state.classes) { + materializedClass.hostOutputs.clear(); + materializedClass.hostOutputToResultIndex.clear(); + } + state.hostOutputOwners.clear(); + + for (Value output : orderedOutputs) { + ClassId ownerClassId = preferredOwners.lookup(output); + MaterializedClass& ownerClass = state.classes[ownerClassId]; + ownerClass.hostOutputToResultIndex[output] = ownerClass.hostOutputs.size(); + ownerClass.hostOutputs.push_back(output); + state.hostOutputOwners[output] = ownerClassId; + } + + return success(); +} + +LogicalResult createEmptyMaterializedOps(MaterializerState& state) { + Location loc = state.func.getLoc(); + Block& funcBlock = state.func.getBody().front(); + + Operation* firstOldCompute = nullptr; + for (Operation& op : funcBlock) { + if (state.oldComputeOps.contains(&op)) { + firstOldCompute = &op; + break; + } + } + + if (firstOldCompute) + state.rewriter.setInsertionPoint(firstOldCompute); + else + state.rewriter.setInsertionPointToStart(&funcBlock); + + for (MaterializedClass& materializedClass : state.classes) { + SmallVector resultTypes; + resultTypes.reserve(materializedClass.hostOutputs.size()); + for (Value output : materializedClass.hostOutputs) + resultTypes.push_back(output.getType()); + + if (!materializedClass.isBatch) { + auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange(resultTypes), ValueRange {}, ValueRange {}); + compute.getProperties().setOperandSegmentSizes({0, 0}); + auto coreIdAttr = + pim::getCheckedI32Attr(state.rewriter, 
state.func, materializedClass.cpus.front(), "materialized core id"); + if (failed(coreIdAttr)) + return failure(); + compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); + Block* body = state.rewriter.createBlock(&compute.getBody()); + state.rewriter.setInsertionPointToEnd(body); + SmallVector placeholderOutputs; + placeholderOutputs.reserve(resultTypes.size()); + for (Type resultType : resultTypes) { + auto tensorType = dyn_cast(resultType); + if (!tensorType || !tensorType.hasStaticShape()) { + compute.emitOpError("host-facing materialized compute results must be static ranked tensors"); + return failure(); + } + placeholderOutputs.push_back( + tensor::EmptyOp::create(state.rewriter, loc, tensorType.getShape(), tensorType.getElementType()).getResult()); + } + SpatYieldOp::create(state.rewriter, loc, ValueRange(placeholderOutputs)); + materializedClass.op = compute.getOperation(); + materializedClass.body = body; + state.rewriter.setInsertionPointAfter(compute.getOperation()); + continue; + } + + auto batchLaneCountAttr = pim::getCheckedI32Attr( + state.rewriter, state.func, materializedClass.cpus.size(), "materialized batch lane count"); + if (failed(batchLaneCountAttr)) + return failure(); + auto batch = SpatScheduledComputeBatch::create( + state.rewriter, loc, TypeRange(resultTypes), *batchLaneCountAttr, ValueRange {}, ValueRange {}); + batch.getProperties().setOperandSegmentSizes({0, 0}); + auto coreIds = getCheckedCoreIds(state.func, materializedClass.cpus, "materialized batch core id"); + if (failed(coreIds)) + return failure(); + batch->setAttr(onnx_mlir::kCoreIdsAttrName, state.rewriter.getDenseI32ArrayAttr(*coreIds)); + + SmallVector blockArgTypes {state.rewriter.getIndexType()}; + SmallVector blockArgLocs {loc}; + llvm::append_range(blockArgTypes, resultTypes); + blockArgLocs.append(resultTypes.size(), loc); + Block* body = + state.rewriter.createBlock(&batch.getBody(), batch.getBody().end(), TypeRange(blockArgTypes), blockArgLocs); + 
state.rewriter.setInsertionPointToEnd(body); + if (resultTypes.empty()) + SpatYieldOp::create(state.rewriter, loc, ValueRange {}); + else + SpatInParallelOp::create(state.rewriter, loc); + materializedClass.op = batch.getOperation(); + materializedClass.body = body; + state.rewriter.setInsertionPointAfter(batch.getOperation()); + } + + return success(); +} + +BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { + auto it = materializedClass.weightArgs.find(weight); + if (it != materializedClass.weightArgs.end()) + return it->second; + + unsigned weightIndex = materializedClass.weights.size(); + materializedClass.weights.push_back(weight); + + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc()); + assert(arg && "expected compute body while inserting a weight"); + materializedClass.weightArgs[weight] = std::get<1>(*arg); + return std::get<1>(*arg); + } + + auto batch = cast(materializedClass.op); + auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc()); + assert(arg && "expected compute_batch body while inserting a weight argument"); + materializedClass.weightArgs[weight] = std::get<1>(*arg); + return std::get<1>(*arg); +} + +BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) { + auto it = materializedClass.inputArgs.find(input); + if (it != materializedClass.inputArgs.end()) + return it->second; + + materializedClass.inputs.push_back(input); + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc()); + assert(arg && "expected compute body while inserting an input"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); + } + if (auto compute = dyn_cast(materializedClass.op)) { + auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, 
input.getLoc()); + assert(arg && "expected compute_batch body while inserting an input argument"); + materializedClass.inputArgs[input] = std::get<1>(*arg); + return std::get<1>(*arg); + } + llvm_unreachable("Cannot reach here"); +} + +// ----------------------------------------------------------------------------- +// Materialized-class value localization helpers. +// ----------------------------------------------------------------------------- + +Region* getParentRegion(Value value) { + if (auto blockArg = dyn_cast(value)) + return blockArg.getOwner()->getParent(); + if (Operation* definingOp = value.getDefiningOp()) + return definingOp->getParentRegion(); + return nullptr; +} + +bool isDefinedInsideRegion(Value value, Region& region) { + Region* parentRegion = getParentRegion(value); + return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion)); +} + +Operation* getEnclosingSpatialComputeLikeOp(Value value) { + Block* block = nullptr; + if (auto blockArg = dyn_cast(value)) + block = blockArg.getOwner(); + else if (Operation* definingOp = value.getDefiningOp()) + block = definingOp->getBlock(); + + if (!block) + return nullptr; + + for (Operation* current = block->getParentOp(); current; current = current->getParentOp()) + if (isa(current)) + return current; + return nullptr; +} + +bool isTensorValueLocalToMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType())) + return true; + if (isConstantLike(value)) + return true; + + Region& targetRegion = *targetClass.body->getParent(); + return isDefinedInsideRegion(value, targetRegion); +} + +bool isTensorValueDefinedInDifferentMaterializedClass(Value value, const MaterializedClass& targetClass) { + if (!isa(value.getType()) || isTensorValueLocalToMaterializedClass(value, targetClass)) + return false; + + Operation* owner = getEnclosingSpatialComputeLikeOp(value); + return owner && owner != targetClass.op; +} + +std::optional 
getRegionIndexInParentOp(Region* region) { + Operation* parent = region ? region->getParentOp() : nullptr; + if (!parent) + return std::nullopt; + + for (auto [index, candidate] : llvm::enumerate(parent->getRegions())) + if (&candidate == region) + return static_cast(index); + return std::nullopt; +} + +std::optional getBlockIndexInRegion(Block* block) { + Region* region = block ? block->getParent() : nullptr; + if (!region) + return std::nullopt; + + for (auto [index, candidate] : llvm::enumerate(region->getBlocks())) + if (&candidate == block) + return static_cast(index); + return std::nullopt; +} + +Block* getBlockByIndex(Region& region, unsigned blockIndex) { + unsigned index = 0; + for (Block& block : region) { + if (index == blockIndex) + return █ + ++index; + } + return nullptr; +} + +static bool isValueLegalInMaterializedClassBody(Value value, const MaterializedClass& targetClass) { + if (isConstantLike(value)) + return true; + + Region& targetRegion = *targetClass.body->getParent(); + return isDefinedInsideRegion(value, targetRegion); +} + +std::string stringifyOperationForMaterializerDebug(Operation* op) { + if (!op) + return std::string(""); + std::string storage; + llvm::raw_string_ostream stream(storage); + op->print(stream); + return storage; +} + +std::string stringifyValueForMaterializerDebug(Value value) { + std::string storage; + llvm::raw_string_ostream stream(storage); + value.print(stream); + return storage; +} + +std::string truncateMaterializerDebugString(std::string text, size_t limit = 1200) { + for (char& ch : text) + if (ch == '\n' || ch == '\r' || ch == '\t') + ch = ' '; + + if (text.size() <= limit) + return text; + text.resize(limit); + text += "..."; + return text; +} + +std::string formatMaterializerOperandListInline(Operation* op, const MaterializedClass& targetClass) { + if (!op) + return std::string(""); + + std::string storage; + llvm::raw_string_ostream stream(storage); + for (OpOperand& operand : op->getOpOperands()) { + if 
(operand.getOperandNumber() != 0) + stream << " | "; + Value value = operand.get(); + stream << "operand#" << operand.getOperandNumber() << " type=" << value.getType() + << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) + << " value=" << stringifyValueForMaterializerDebug(value); + if (auto blockArg = dyn_cast(value)) { + stream << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + stream << " ownerOp='" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + stream << " definingOp='" << definingOp->getName() << "'"; + } + } + return truncateMaterializerDebugString(stream.str()); +} + +std::string formatMaterializerParentChainInline(Operation* op) { + if (!op) + return std::string(""); + + std::string storage; + llvm::raw_string_ostream stream(storage); + unsigned depth = 0; + for (Operation* current = op; current; current = current->getParentOp()) { + if (depth != 0) + stream << " <- "; + stream << "[" << depth++ << "]" << current->getName(); + } + return truncateMaterializerDebugString(stream.str()); +} + +void attachMaterializerOperationPrintNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { + if (!op) + return; + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stringifyOperationForMaterializerDebug(op); +} + +void attachMaterializerParentChainNote(InFlightDiagnostic& diagnostic, Operation* op, StringRef label) { + if (!op) + return; + + std::string storage; + llvm::raw_string_ostream stream(storage); + unsigned depth = 0; + for (Operation* current = op; current; current = current->getParentOp()) + stream << " [" << depth++ << "] " << current->getName() << "\n"; + + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); +} + +void attachMaterializerOperandListNote(InFlightDiagnostic& diagnostic, + Operation* op, + const MaterializedClass& targetClass, + StringRef label) { + if (!op) + return; + + 
std::string storage; + llvm::raw_string_ostream stream(storage); + for (OpOperand& operand : op->getOpOperands()) { + Value value = operand.get(); + stream << " operand#" << operand.getOperandNumber() << " type=" << value.getType() + << " local=" << (isValueLegalInMaterializedClassBody(value, targetClass) ? 1 : 0) + << " value=" << stringifyValueForMaterializerDebug(value); + if (auto blockArg = dyn_cast(value)) { + stream << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + stream << " ownerOp='" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + stream << " definingOp='" << definingOp->getName() << "'"; + } + stream << "\n"; + } + + diagnostic.attachNote(op->getLoc()) << label << ":\n" << stream.str(); +} + +void attachMaterializerValueOriginNote(InFlightDiagnostic& diagnostic, Value value, StringRef label) { + if (auto blockArg = dyn_cast(value)) { + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic.attachNote(owner->getLoc()) + << label << " is block argument #" << blockArg.getArgNumber() << " of '" << owner->getName() + << "' with type " << blockArg.getType(); + else + diagnostic.attachNote(UnknownLoc::get(value.getContext())) + << label << " is a top-level block argument #" << blockArg.getArgNumber() + << " with type " << blockArg.getType(); + return; + } + + if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) + << label << " is defined by '" << definingOp->getName() << "' with result type " << value.getType(); + return; + } + + diagnostic.attachNote(UnknownLoc::get(value.getContext())) + << label << " has no defining operation and is not a block argument, type " << value.getType(); +} + +void attachMaterializedClassBodySummary(InFlightDiagnostic& diagnostic, const MaterializedClass& targetClass) { + Block& body = *targetClass.body; + diagnostic.attachNote(targetClass.op->getLoc()) + << 
"RAPTOR_MATERIALIZER_DEBUG target class " << targetClass.id << " op '" << targetClass.op->getName() + << "' body has " << body.getNumArguments() << " block arguments and " + << std::distance(body.begin(), body.end()) << " top-level operations"; +} + +FailureOr rematerializeIndexValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Location loc, + IRMapping* mapper = nullptr); + +FailureOr rematerializeIndexOpFoldResultInClass(MaterializerState& state, + MaterializedClass& targetClass, + OpFoldResult value, + Location loc, + IRMapping* mapper = nullptr) { + if (auto attr = dyn_cast(value)) + return OpFoldResult(attr); + + FailureOr rematerialized = rematerializeIndexValueInClass(state, targetClass, cast(value), loc, mapper); + if (failed(rematerialized)) + return failure(); + return OpFoldResult(*rematerialized); +} + +FailureOr rematerializeIndexValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Location loc, + IRMapping* mapper) { + Value originalValue = value; + bool mapperHadOriginalValue = false; + Value mappedOriginalValue; + + if (mapper && mapper->contains(value)) { + mapperHadOriginalValue = true; + Value mapped = mapper->lookup(value); + mappedOriginalValue = mapped; + if (isValueLegalInMaterializedClassBody(mapped, targetClass) || isConstantLike(mapped)) + return mapped; + value = mapped; + } + + if (isValueLegalInMaterializedClassBody(value, targetClass)) + return value; + + if (!value.getType().isIndex()) + return targetClass.op->emitError("cannot rematerialize non-index external value in materialized class body") + << " type=" << value.getType(); + + if (auto constantIndex = value.getDefiningOp()) + return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantIndex.value()); + + APInt constantValue; + if (matchPattern(value, m_ConstantInt(&constantValue))) { + if (!constantValue.isSignedIntN(64)) + return targetClass.op->emitError("cannot rematerialize 
out-of-range index constant") + << " value=" << llvm::toString(constantValue, 10, /*Signed=*/true); + return getOrCreateIndexConstant(state.constantFolder, targetClass.op, constantValue.getSExtValue()); + } + + if (auto affineApply = value.getDefiningOp()) { + SmallVector remappedOperands; + remappedOperands.reserve(affineApply.getMapOperands().size()); + for (Value operand : affineApply.getMapOperands()) { + FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, operand, loc, mapper); + if (failed(remapped)) + return failure(); + remappedOperands.push_back(*remapped); + } + return createOrFoldAffineApply(state.rewriter, loc, affineApply.getAffineMap(), remappedOperands, state.func); + } + + if (auto addOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, addOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, addOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::AddIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto subOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, subOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, subOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::SubIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto mulOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getLhs(), loc, mapper); + FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, mulOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::MulIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto divOp = value.getDefiningOp()) { + FailureOr lhs = rematerializeIndexValueInClass(state, targetClass, divOp.getLhs(), loc, mapper); + 
FailureOr rhs = rematerializeIndexValueInClass(state, targetClass, divOp.getRhs(), loc, mapper); + if (failed(lhs) || failed(rhs)) + return failure(); + return arith::DivUIOp::create(state.rewriter, loc, *lhs, *rhs).getResult(); + } + + if (auto extractOp = value.getDefiningOp()) { + SmallVector remappedIndices; + remappedIndices.reserve(extractOp.getIndices().size()); + for (Value index : extractOp.getIndices()) { + FailureOr remapped = rematerializeIndexValueInClass(state, targetClass, index, loc, mapper); + if (failed(remapped)) + return failure(); + remappedIndices.push_back(*remapped); + } + + Value tensor = extractOp.getTensor(); + if (!isConstantLike(tensor) && !isValueLegalInMaterializedClassBody(tensor, targetClass)) + return targetClass.op->emitError("cannot rematerialize indexed table lookup from external non-constant tensor") + << " tensorType=" << tensor.getType(); + return tensor::ExtractOp::create(state.rewriter, loc, tensor, remappedIndices).getResult(); + } + + if (auto blockArg = dyn_cast(value)) { + InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external block argument in materialized class body"); + diagnostic << " currentArg#" << blockArg.getArgNumber() << " currentType=" << blockArg.getType() + << " targetClass=" << targetClass.id << " targetOp='" << targetClass.op->getName() << "'"; + if (Operation* owner = blockArg.getOwner()->getParentOp()) { + diagnostic << " ownerOp='" << owner->getName() << "'"; + diagnostic << " ownerIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(owner)) << "\""; + diagnostic << " ownerChain=\"" << formatMaterializerParentChainInline(owner) << "\""; + } + diagnostic << " targetIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(targetClass.op)) << "\""; + if (mapper) { + diagnostic << " mapperPresent=1 mapperHadOriginal=" << (mapperHadOriginalValue ? 
1 : 0); + if (mapperHadOriginalValue) + diagnostic << " mappedType=" << mappedOriginalValue.getType(); + } else { + diagnostic << " mapperPresent=0"; + } + attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); + if (value != originalValue) + attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); + if (mapperHadOriginalValue && mappedOriginalValue != value) + attachMaterializerValueOriginNote(diagnostic, mappedOriginalValue, "mapper value"); + if (Operation* owner = blockArg.getOwner()->getParentOp()) { + attachMaterializerOperationPrintNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner op"); + attachMaterializerParentChainNote(diagnostic, owner, "RAPTOR_MATERIALIZER_DEBUG external block argument owner parent chain"); + } + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); + } + + InFlightDiagnostic diagnostic = + targetClass.op->emitError("RAPTOR_MATERIALIZER_DEBUG cannot rematerialize external index value in materialized class body"); + diagnostic << " type=" << value.getType() << " targetClass=" << targetClass.id << " targetOp='" + << targetClass.op->getName() << "'"; + attachMaterializerValueOriginNote(diagnostic, originalValue, "original value"); + if (value != originalValue) + attachMaterializerValueOriginNote(diagnostic, value, "mapped/current value"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); +} + +InFlightDiagnostic emitNonLocalMaterializedClassValueDiagnostic(Operation* anchor, + const MaterializedClass& targetClass, + StringRef context, + Value value, + std::optional producer = std::nullopt) { + InFlightDiagnostic diagnostic = anchor->emitError(context) << " into target class " << targetClass.id; + + if (producer) { + diagnostic << " from '" << producer->instance.op->getName() << "' 
resultIndex=" << producer->resultIndex + << " laneStart=" << producer->instance.laneStart << " laneCount=" << producer->instance.laneCount; + } else if (auto result = dyn_cast(value)) { + diagnostic << " from '" << result.getOwner()->getName() << "' resultIndex=" << result.getResultNumber(); + } else if (auto blockArg = dyn_cast(value)) { + diagnostic << " from block argument #" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic << " of '" << owner->getName() << "'"; + } + + if (Operation* definingOp = value.getDefiningOp()) + diagnostic.attachNote(definingOp->getLoc()) << "offending tensor producer is '" << definingOp->getName() << "'"; + return diagnostic; +} + +FailureOr rematerializeTensorValueInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + IRMapping* mapper) { + auto extractSlice = value.getDefiningOp(); + if (extractSlice) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, targetClass, extractSlice.getSource(), anchor, context, std::nullopt, mapper); + if (failed(localizedSource)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(extractSlice.getMixedOffsets().size()); + sizes.reserve(extractSlice.getMixedSizes().size()); + strides.reserve(extractSlice.getMixedStrides().size()); + + for (OpFoldResult offset : extractSlice.getMixedOffsets()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, offset, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + offsets.push_back(*localized); + } + for (OpFoldResult size : extractSlice.getMixedSizes()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, size, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + sizes.push_back(*localized); + } + for (OpFoldResult stride : 
extractSlice.getMixedStrides()) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, stride, anchor->getLoc(), mapper); + if (failed(localized)) + return failure(); + strides.push_back(*localized); + } + + return tensor::ExtractSliceOp::create(state.rewriter, anchor->getLoc(), *localizedSource, offsets, sizes, strides) + .getResult(); + } + + if (auto collapseShape = value.getDefiningOp()) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, targetClass, collapseShape.getSrc(), anchor, context, std::nullopt, mapper); + if (failed(localizedSource)) + return failure(); + return tensor::CollapseShapeOp::create( + state.rewriter, anchor->getLoc(), *localizedSource, collapseShape.getReassociationIndices()) + .getResult(); + } + + return failure(); +} + +FailureOr materializeTensorValueForMaterializedClassUse(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef context, + std::optional producer, + IRMapping* mapper) { + if (mapper && mapper->contains(value)) + value = mapper->lookup(value); + + if (!isa(value.getType()) || isConstantLike(value) || isTensorValueLocalToMaterializedClass(value, targetClass)) + return value; + + if (value.getDefiningOp() || value.getDefiningOp()) { + FailureOr rematerialized = rematerializeTensorValueInClass(state, targetClass, value, anchor, context, mapper); + if (failed(rematerialized)) + return failure(); + return *rematerialized; + } + + if (isTensorValueDefinedInDifferentMaterializedClass(value, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic(anchor, targetClass, context, value, producer); + return failure(); + } + + return appendInput(state, targetClass, value); +} + +std::optional mapExternalRegionBlockArgumentToLocalClone(const MaterializedClass& targetClass, + Operation* anchor, + BlockArgument externalArg) { + Block* sourceBlock = externalArg.getOwner(); + Region* sourceRegion = sourceBlock ? 
sourceBlock->getParent() : nullptr; + Operation* sourceParent = sourceRegion ? sourceRegion->getParentOp() : nullptr; + if (!sourceParent || !anchor) + return std::nullopt; + + std::optional sourceRegionIndex = getRegionIndexInParentOp(sourceRegion); + std::optional sourceBlockIndex = getBlockIndexInRegion(sourceBlock); + if (!sourceRegionIndex || !sourceBlockIndex) + return std::nullopt; + + for (Operation* current = anchor->getParentOp(); current && current != targetClass.op; + current = current->getParentOp()) { + if (current->getName() != sourceParent->getName()) + continue; + if (current->getNumRegions() <= *sourceRegionIndex) + continue; + + Region& localRegion = current->getRegion(*sourceRegionIndex); + Block* localBlock = getBlockByIndex(localRegion, *sourceBlockIndex); + if (!localBlock || localBlock->getNumArguments() <= externalArg.getArgNumber()) + continue; + + BlockArgument localArg = localBlock->getArgument(externalArg.getArgNumber()); + if (localArg.getType() != externalArg.getType()) + continue; + if (!isValueLegalInMaterializedClassBody(localArg, targetClass)) + continue; + return localArg; + } + + return std::nullopt; +} + +FailureOr localizeMaterializedClassOperand(MaterializerState& state, + MaterializedClass& targetClass, + Value value, + Operation* anchor, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper) { + if (mapper && mapper->contains(value)) + value = mapper->lookup(value); + + if (auto blockArg = dyn_cast(value)) + if (std::optional localArg = mapExternalRegionBlockArgumentToLocalClone(targetClass, anchor, blockArg)) + return *localArg; + + if (isa(value.getType())) + return materializeTensorValueForMaterializedClassUse(state, targetClass, value, anchor, tensorContext, std::nullopt, mapper); + + if (isValueLegalInMaterializedClassBody(value, targetClass)) + return value; + + if (value.getType().isIndex()) + return rematerializeIndexValueInClass(state, targetClass, value, anchor->getLoc(), mapper); + + 
InFlightDiagnostic diagnostic = anchor->emitError(genericContext); + diagnostic << " type=" << value.getType(); + if (auto blockArg = dyn_cast(value)) { + diagnostic << " blockArg#" << blockArg.getArgNumber(); + if (Operation* owner = blockArg.getOwner()->getParentOp()) + diagnostic.attachNote(owner->getLoc()) << "block argument belongs to '" << owner->getName() << "'"; + } else if (Operation* definingOp = value.getDefiningOp()) { + diagnostic.attachNote(definingOp->getLoc()) << "unsupported external operand producer is '" << definingOp->getName() + << "'"; + } + return failure(); +} + +// ----------------------------------------------------------------------------- +// Tensor packing helpers. +// ----------------------------------------------------------------------------- + +struct Dim0SliceParams { + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; + +Dim0SliceParams +buildDim0SliceParams(OpBuilder& builder, RankedTensorType referenceType, OpFoldResult firstOffset, int64_t firstSize) { + Dim0SliceParams params; + params.offsets.reserve(referenceType.getRank()); + params.sizes.reserve(referenceType.getRank()); + params.strides.reserve(referenceType.getRank()); + + params.offsets.push_back(firstOffset); + params.sizes.push_back(builder.getIndexAttr(firstSize)); + params.strides.push_back(builder.getIndexAttr(1)); + + for (int64_t dim = 1; dim < referenceType.getRank(); ++dim) { + params.offsets.push_back(builder.getIndexAttr(0)); + params.sizes.push_back(builder.getIndexAttr(referenceType.getDimSize(dim))); + params.strides.push_back(builder.getIndexAttr(1)); + } + + return params; +} + +Value createDim0ExtractSlice( + MaterializerState& state, Location loc, Value source, OpFoldResult firstOffset, int64_t firstSize) { + auto sourceType = cast(source.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, sourceType, firstOffset, firstSize); + return tensor::ExtractSliceOp::create(state.rewriter, loc, source, params.offsets, 
params.sizes, params.strides) + .getResult(); +} + +FailureOr createDim0ExtractSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value source, + OpFoldResult firstOffset, + int64_t firstSize) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + source, + targetClass.op, + "createDim0ExtractSliceInClass tried to reuse a tensor from another materialized class"); + if (failed(localizedSource)) + return failure(); + FailureOr localizedOffset = + rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + if (failed(localizedOffset)) + return failure(); + return createDim0ExtractSlice(state, loc, *localizedSource, *localizedOffset, firstSize); +} + +Value createStaticExtractSlice(MaterializerState& state, + Location loc, + Value source, + ArrayRef sliceOffsets, + ArrayRef resultShape) { + auto sourceType = cast(source.getType()); + assert(sliceOffsets.size() == static_cast(sourceType.getRank()) && "offset rank mismatch"); + assert(resultShape.size() == static_cast(sourceType.getRank()) && "result rank mismatch"); + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(sourceType.getRank()); + sizes.reserve(sourceType.getRank()); + strides.reserve(sourceType.getRank()); + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + offsets.push_back(sliceOffsets[dim]); + sizes.push_back(state.rewriter.getIndexAttr(resultShape[dim])); + strides.push_back(state.rewriter.getIndexAttr(1)); + } + + return tensor::ExtractSliceOp::create(state.rewriter, loc, source, offsets, sizes, strides).getResult(); +} + +FailureOr createStaticExtractSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value source, + ArrayRef sliceOffsets, + ArrayRef resultShape) { + FailureOr localizedSource = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + source, + targetClass.op, + 
"createStaticExtractSliceInClass tried to reuse a tensor from another materialized class"); + if (failed(localizedSource)) + return failure(); + + SmallVector localizedOffsets; + localizedOffsets.reserve(sliceOffsets.size()); + for (OpFoldResult offset : sliceOffsets) { + FailureOr localized = + rematerializeIndexOpFoldResultInClass(state, targetClass, offset, loc); + if (failed(localized)) + return failure(); + localizedOffsets.push_back(*localized); + } + return createStaticExtractSlice(state, loc, *localizedSource, localizedOffsets, resultShape); +} + +Value createIndexedIndexValue(MaterializerState& state, + Operation* anchor, + ArrayRef values, + Value index, + Location loc, + std::optional preferredPeriod = std::nullopt, + bool allowExhaustiveTiledSearch = true); + +FailureOr> buildProjectedFragmentOffsetsInClass(MaterializerState& state, + MaterializedClass& targetClass, + const ProjectedTransferDescriptor& descriptor, + Value flatFragmentIndex, + Location loc) { + FailureOr localizedIndex = rematerializeIndexValueInClass(state, targetClass, flatFragmentIndex, loc); + if (failed(localizedIndex)) + return failure(); + SmallVector fragmentOffsets; + fragmentOffsets.reserve(descriptor.layout.fragmentShape.size()); + for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) + fragmentOffsets.push_back(createIndexedIndexValue(state, + targetClass.op, + dimOffsets, + *localizedIndex, + loc, + static_cast(descriptor.layout.payloadFragmentCount), + /*allowExhaustiveTiledSearch=*/false)); + return fragmentOffsets; +} + +Value createDim0InsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { + auto fragmentType = cast(fragment.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); + return tensor::InsertSliceOp::create( + state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides) + .getResult(); +} 
+ +FailureOr createDim0InsertSliceInClass(MaterializerState& state, + MaterializedClass& targetClass, + Location loc, + Value fragment, + Value destination, + OpFoldResult firstOffset) { + FailureOr localizedFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment, + targetClass.op, + "createDim0InsertSliceInClass tried to reuse a fragment tensor from another materialized class"); + if (failed(localizedFragment)) + return failure(); + FailureOr localizedDestination = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + destination, + targetClass.op, + "createDim0InsertSliceInClass tried to reuse a destination tensor from another materialized class"); + if (failed(localizedDestination)) + return failure(); + FailureOr localizedOffset = + rematerializeIndexOpFoldResultInClass(state, targetClass, firstOffset, loc); + if (failed(localizedOffset)) + return failure(); + return createDim0InsertSlice(state, loc, *localizedFragment, *localizedDestination, *localizedOffset); +} + +void createDim0ParallelInsertSlice( + MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { + auto fragmentType = cast(fragment.getType()); + Dim0SliceParams params = buildDim0SliceParams(state.rewriter, fragmentType, firstOffset, fragmentType.getDimSize(0)); + tensor::ParallelInsertSliceOp::create( + state.rewriter, loc, fragment, destination, params.offsets, params.sizes, params.strides); +} + +Value scaleIndexByDim0Size(MaterializerState& state, Operation* anchor, Value index, int64_t dim0Size, Location loc) { + if (dim0Size == 1) + return index; + + Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, anchor, dim0Size); + return arith::MulIOp::create(state.rewriter, loc, index, dim0SizeValue).getResult(); +} + +FailureOr scaleIndexByDim0SizeInClass(MaterializerState& state, + MaterializedClass& targetClass, + Value index, + int64_t dim0Size, + Location loc) { + FailureOr 
localizedIndex = rematerializeIndexValueInClass(state, targetClass, index, loc); + if (failed(localizedIndex)) + return failure(); + if (dim0Size == 1) + return *localizedIndex; + + Value dim0SizeValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, dim0Size); + return arith::MulIOp::create(state.rewriter, loc, *localizedIndex, dim0SizeValue).getResult(); +} + +bool sameProducerResult(ProducerKey lhs, ProducerKey rhs) { + return lhs.instance.op == rhs.instance.op && lhs.resultIndex == rhs.resultIndex; +} + +bool containsProducerKey(ProducerKey outer, ProducerKey inner) { + if (!sameProducerResult(outer, inner)) + return false; + if (!isa(outer.instance.op)) + return false; + if (outer.instance.laneCount == 0 || inner.instance.laneCount == 0) + return false; + + uint32_t outerStart = outer.instance.laneStart; + uint32_t outerEnd = outerStart + outer.instance.laneCount; + uint32_t innerStart = inner.instance.laneStart; + uint32_t innerEnd = innerStart + inner.instance.laneCount; + + return outerStart <= innerStart && innerEnd <= outerEnd; +} + +std::optional extractPackedProducerSlice(MaterializerState& state, + MaterializedClass& materializedClass, + ProducerKey packedKey, + Value packed, + ProducerKey requestedKey) { + if (!containsProducerKey(packedKey, requestedKey)) + return std::nullopt; + + auto packedType = dyn_cast(packed.getType()); + if (!packedType || !packedType.hasStaticShape() || packedType.getRank() == 0) + return std::nullopt; + + if (packedKey.instance.laneCount == 0) + return std::nullopt; + + int64_t packedRows = packedType.getDimSize(0); + if (packedRows % static_cast(packedKey.instance.laneCount) != 0) + return std::nullopt; + + int64_t rowsPerLane = packedRows / static_cast(packedKey.instance.laneCount); + int64_t rowOffset = + static_cast(requestedKey.instance.laneStart - packedKey.instance.laneStart) * rowsPerLane; + int64_t rowCount = static_cast(requestedKey.instance.laneCount) * rowsPerLane; + + 
state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, materializedClass.op, rowOffset); + return createDim0ExtractSlice(state, materializedClass.op->getLoc(), packed, firstOffset, rowCount); +} + +std::optional AvailableValueStore::lookupExact(ProducerKey key, ClassId classId) const { + auto producerIt = exactValues.find(key); + if (producerIt == exactValues.end()) + return std::nullopt; + + auto valueIt = producerIt->second.find(classId); + if (valueIt == producerIt->second.end()) + return std::nullopt; + + return valueIt->second; +} + +Value getPackedSliceForRunIndex(MaterializerState& state, + Operation* anchor, + Value packed, + RankedTensorType fragmentType, + size_t index, + Location loc) { + int64_t rowOffset = static_cast(index) * fragmentType.getDimSize(0); + Value firstOffset = getOrCreateIndexConstant(state.constantFolder, anchor, rowOffset); + return createDim0ExtractSlice(state, loc, packed, firstOffset, fragmentType.getDimSize(0)); +} + +FailureOr createReceiveConcatLoop(MaterializerState& state, + MaterializedClass& targetClass, + RankedTensorType concatType, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc); + +using IndexedFragmentBuilder = llvm::function_ref(Value flatIndex)>; +using IndexedInsertOffsetBuilder = llvm::function_ref(Value flatIndex)>; + +FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc); + +bool isDeferredLocalPackedScalarRun(const PackedScalarRunValue& run) { + return run.kind == PackedScalarRunKind::DeferredLocalCompute; +} + +size_t getPackedScalarRunReceiveCount(const PackedScalarRunValue& run) { + size_t count = 0; + for (const PackedScalarRunSlot& slot : run.slots) + count += slot.keys.size(); + return count; +} + +LogicalResult validatePackedScalarRunMetadata(Operation* anchor, const 
PackedScalarRunValue& run) { + if (run.kind == PackedScalarRunKind::DeferredLocalCompute) + return success(); + + size_t receiveCount = getPackedScalarRunReceiveCount(run); + + if (receiveCount == 0) + return anchor->emitError("packed scalar run has no receives"); + + if (failed(run.messages.verify(anchor))) + return failure(); + + if (run.messages.size() != receiveCount) + return anchor->emitError("packed scalar run receive metadata count is inconsistent"); + + return success(); +} + +FailureOr materializePackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc) { + if (run.packed) + return run.packed; + + if (run.kind == PackedScalarRunKind::Materialized) + return targetClass.op->emitError("materialized packed scalar run has no packed value"); + + if (isDeferredLocalPackedScalarRun(run)) + return materializeDeferredLocalPackedScalarRunValue(state, targetClass, run, loc); + + if (failed(validatePackedScalarRunMetadata(targetClass.op, run))) + return failure(); + + FailureOr fullPackedType = + getPackedBatchTensorType(run.fragmentType, getPackedScalarRunReceiveCount(run)); + if (failed(fullPackedType)) + return targetClass.op->emitError("cannot create lazy packed scalar run receive type"); + + auto packed = createReceiveConcatLoop(state, targetClass, *fullPackedType, run.fragmentType, run.messages, loc); + if (failed(packed)) + return failure(); + run.packed = *packed; + return run.packed; +} + +std::optional AvailableValueStore::lookupPackedRun(MaterializerState& state, ProducerKey key, ClassId classId) { + for (PackedScalarRunValue& run : packedScalarRuns) { + if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + + for (auto [slotIndex, slot] : llvm::enumerate(run.slots)) { + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); + auto exactKeyIt = llvm::find(slot.keys, key); + if ((!contiguousKey || 
!containsProducerKey(*contiguousKey, key)) && exactKeyIt == slot.keys.end()) + continue; + + FailureOr slotPackedType = getPackedBatchTensorType(run.fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return std::nullopt; + + MaterializedClass& materializedClass = state.classes[classId]; + state.rewriter.setInsertionPoint(materializedClass.body->getTerminator()); + + FailureOr packed = + materializePackedScalarRunValue(state, materializedClass, run, materializedClass.op->getLoc()); + if (failed(packed)) + return std::nullopt; + + Value slotPacked = + getPackedSliceForRunIndex(state, materializedClass.op, *packed, *slotPackedType, slotIndex, (*packed).getLoc()); + + if (contiguousKey && *contiguousKey == key) { + record(key, classId, slotPacked); + return slotPacked; + } + + if (contiguousKey && containsProducerKey(*contiguousKey, key)) { + std::optional sliced = + extractPackedProducerSlice(state, materializedClass, *contiguousKey, slotPacked, key); + if (!sliced) + return std::nullopt; + + record(key, classId, *sliced); + return *sliced; + } + + if (exactKeyIt != slot.keys.end() && key.instance.laneCount == 1) { + size_t keyIndex = static_cast(std::distance(slot.keys.begin(), exactKeyIt)); + Value sliced = getPackedSliceForRunIndex( + state, materializedClass.op, slotPacked, run.fragmentType, keyIndex, (*packed).getLoc()); + record(key, classId, sliced); + return sliced; + } + } + } + + return std::nullopt; +} + +IndexedBatchRunValue* AvailableValueStore::lookupIndexedBatchRun(ProducerKey key, ClassId classId) { + for (IndexedBatchRunValue& run : indexedBatchRuns) { + if (run.targetClass != classId || run.sourceOp != key.instance.op || run.resultIndex != key.resultIndex) + continue; + for (const PackedScalarRunSlot& slot : run.slots) { + if (!llvm::is_contained(slot.keys, key)) + continue; + return &run; + } + } + return nullptr; +} + +std::optional AvailableValueStore::lookup(MaterializerState& state, ProducerKey key, ClassId classId) { + + if 
(std::optional exact = lookupExact(key, classId)) { + return exact; + } + + if (std::optional packedRunValue = lookupPackedRun(state, key, classId)) + return packedRunValue; + + MaterializedClass& materializedClass = state.classes[classId]; + + for (const auto& [candidateKey, classValues] : exactValues) { + if (!sameProducerResult(candidateKey, key) || !containsProducerKey(candidateKey, key)) + continue; + + auto valueIt = classValues.find(classId); + if (valueIt == classValues.end()) + continue; + + std::optional slice = + extractPackedProducerSlice(state, materializedClass, candidateKey, valueIt->second, key); + if (!slice) + return std::nullopt; + + record(key, classId, *slice); + return *slice; + } + return std::nullopt; +} + +Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { + SmallVector elements; + elements.reserve(values.size()); + for (int64_t value : values) + elements.push_back(APInt(64, value)); + + auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); + auto attr = DenseIntElementsAttr::get(type, elements); + return getOrCreateConstant(state.constantFolder, anchor, attr, type); +} + +bool allEqual(ArrayRef values) { + assert(!values.empty() && "expected at least one value"); + for (int64_t value : values.drop_front()) + if (value != values.front()) + return false; + return true; +} + +struct IndexedIndexPattern { + int64_t base = 0; + int64_t step = 0; + int64_t period = 1; + int64_t innerStep = 0; + int64_t outerStep = 0; + bool isTiled = false; +}; + +bool matchAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) { + assert(!values.empty() && "expected at least one value"); + + pattern.base = values.front(); + pattern.step = values.size() == 1 ? 
0 : values[1] - values[0];
+  pattern.isTiled = false;
+
+  for (auto [index, value] : llvm::enumerate(values)) {
+    int64_t expected = pattern.base + pattern.step * static_cast(index);
+    if (value != expected)
+      return false;
+  }
+
+  return true;
+}
+// Checks the tiled form for one `period` (must satisfy 2 <= period <= size / 2); `pattern` is written only on success.
+bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern, int64_t period) {
+  assert(!values.empty() && "expected at least one value");
+  if (period < 2 || period > static_cast(values.size() / 2))
+    return false;
+
+  int64_t base = values.front();
+  int64_t innerStep = values[1] - values[0]; // in range: the guard above implies size >= 4
+  int64_t outerStep = values[period] - values[0]; // in range: period <= size / 2
+
+  for (auto [index, value] : llvm::enumerate(values)) {
+    int64_t i = static_cast(index);
+    int64_t expected = base + outerStep * (i / period) + innerStep * (i % period);
+    if (value != expected)
+      return false;
+  }
+
+  pattern.base = base;
+  pattern.period = period;
+  pattern.innerStep = innerStep;
+  pattern.outerStep = outerStep;
+  pattern.isTiled = true;
+  return true;
+}
+// Tries periods 2, 3, ..., size / 2 in increasing order and stops at the first one that matches.
+bool matchTiledAffineSequence(ArrayRef values, IndexedIndexPattern& pattern) {
+  assert(!values.empty() && "expected at least one value");
+
+  for (int64_t period = 2; period <= static_cast(values.size() / 2); ++period)
+    if (matchTiledAffineSequence(values, pattern, period))
+      return true;
+
+  return false;
+}
+// Match order: plain affine, then `preferredPeriod` if given, then (if allowed and size <= 256) every period.
+std::optional getIndexedIndexPattern(ArrayRef values,
+                                     std::optional preferredPeriod = std::nullopt,
+                                     bool allowExhaustiveTiledSearch = true) {
+  assert(!values.empty() && "expected at least one value");
+
+  IndexedIndexPattern pattern;
+  if (matchAffineSequence(values, pattern))
+    return pattern;
+  if (preferredPeriod && matchTiledAffineSequence(values, pattern, *preferredPeriod))
+    return pattern;
+  if (allowExhaustiveTiledSearch && values.size() <= 256 && matchTiledAffineSequence(values, pattern))
+    return pattern;
+
+  return std::nullopt;
+}
+// Applies the pattern's closed form to `index` through a one-dimensional affine map (created or folded).
+Value createAffineIndexValue(MaterializerState& state, const IndexedIndexPattern& pattern, Value index, Location loc) {
+
MLIRContext* context = state.func.getContext();
+  AffineExpr d0 = getAffineDimExpr(0, context);
+
+  AffineExpr expr;
+  if (!pattern.isTiled) {
+    expr = getAffineConstantExpr(pattern.base, context) + d0 * pattern.step; // base + step * d0
+  }
+  else {
+    expr = getAffineConstantExpr(pattern.base, context) + d0.floorDiv(pattern.period) * pattern.outerStep
+         + (d0 % pattern.period) * pattern.innerStep; // base + outerStep * (d0 / period) + innerStep * (d0 % period)
+  }
+
+  AffineMap map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+  return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {index}, state.func);
+}
+/* Emits IR computing values[index]: a constant if all entries are equal, an affine expression of `index` if
+ * getIndexedIndexPattern finds one, otherwise a tensor extract from a constant table of `values`. */
+Value createIndexedIndexValue(MaterializerState& state,
+                              Operation* anchor,
+                              ArrayRef values,
+                              Value index,
+                              Location loc,
+                              std::optional preferredPeriod,
+                              bool allowExhaustiveTiledSearch) {
+  assert(!values.empty() && "expected at least one indexed value");
+
+  if (allEqual(values)) {
+    return getOrCreateIndexConstant(state.constantFolder, anchor, values.front());
+  }
+
+  if (std::optional pattern =
+          getIndexedIndexPattern(values, preferredPeriod, allowExhaustiveTiledSearch))
+    return createAffineIndexValue(state, *pattern, index, loc);
+  Value table = createIndexTensorConstant(state, anchor, values); // fallback: explicit lookup table
+  return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {index}).getResult();
+}
+// Overload iterating `values` as int32_t: copies them into `widened` and forwards with no preferred period.
+Value createIndexedIndexValue(
+    MaterializerState& state, Operation* anchor, ArrayRef values, Value index, Location loc) {
+  assert(!values.empty() && "expected at least one indexed value");
+
+  SmallVector widened;
+  widened.reserve(values.size());
+  for (int32_t value : values)
+    widened.push_back(value);
+
+  return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, std::nullopt, true);
+}
+// Returns an index attribute when all values are equal; otherwise defers to createIndexedIndexValue.
+OpFoldResult createIndexedOrStaticIndex(MaterializerState& state,
+                                        Operation* anchor,
+                                        ArrayRef values,
+                                        Value index,
+                                        Location loc) {
+  assert(!values.empty() && "expected at least one indexed value");
+  if (allEqual(values))
+    return state.rewriter.getIndexAttr(values.front());
+  return
createIndexedIndexValue(state, anchor, values, index, loc);
+}
+// Builds the value selecting `messages.channelIds[index]` through createIndexedIndexValue.
+Value createIndexedChannelId(
+    MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) {
+  return createIndexedIndexValue(state, anchor, ArrayRef(messages.channelIds), index, loc);
+}
+// Same selection over `messages.channelIds`, forwarding `preferredPeriod` with the exhaustive tiled search enabled.
+Value createIndexedChannelId(MaterializerState& state,
+                             Operation* anchor,
+                             const MessageVector& messages,
+                             Value index,
+                             Location loc,
+                             std::optional preferredPeriod) {
+  return createIndexedIndexValue(
+      state, anchor, ArrayRef(messages.channelIds), index, loc, preferredPeriod, true);
+}
+// Builds the value selecting `messages.sourceCoreIds[index]` through createIndexedIndexValue.
+Value createIndexedSourceCoreId(
+    MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) {
+  return createIndexedIndexValue(state, anchor, ArrayRef(messages.sourceCoreIds), index, loc);
+}
+// Copies `messages.sourceCoreIds` into `widened`, then forwards `preferredPeriod` with exhaustive search enabled.
+Value createIndexedSourceCoreId(MaterializerState& state,
+                                Operation* anchor,
+                                const MessageVector& messages,
+                                Value index,
+                                Location loc,
+                                std::optional preferredPeriod) {
+  SmallVector widened(messages.sourceCoreIds.begin(), messages.sourceCoreIds.end());
+  return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true);
+}
+// Builds the value selecting `messages.targetCoreIds[index]` through createIndexedIndexValue.
+Value createIndexedTargetCoreId(
+    MaterializerState& state, Operation* anchor, const MessageVector& messages, Value index, Location loc) {
+  return createIndexedIndexValue(state, anchor, ArrayRef(messages.targetCoreIds), index, loc);
+}
+// Copies `messages.targetCoreIds` into `widened`, then forwards `preferredPeriod` with exhaustive search enabled.
+Value createIndexedTargetCoreId(MaterializerState& state,
+                                Operation* anchor,
+                                const MessageVector& messages,
+                                Value index,
+                                Location loc,
+                                std::optional preferredPeriod) {
+  SmallVector widened(messages.targetCoreIds.begin(), messages.targetCoreIds.end());
+  return createIndexedIndexValue(state, anchor, ArrayRef(widened), index, loc, preferredPeriod, true);
+}
+// Indexes `values` by the lane argument of the materialized batch op; asserts a batch class with one value per cpu.
+Value createLaneIndexedIndexValue(MaterializerState& state,
+                                  MaterializedClass& materializedClass,
+                                  ArrayRef values,
+                                  Location loc) {
+  assert(materializedClass.isBatch
&& "lane-indexed value requires a materialized batch class");
+  assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
+
+  auto batch = cast(materializedClass.op);
+  auto laneArg = batch.getLaneArgument();
+  assert(laneArg && "expected compute_batch lane argument");
+
+  return createIndexedIndexValue(state, materializedClass.op, values, *laneArg, loc);
+}
+// Overload iterating `values` as int32_t: copies them into `widened` and forwards to createLaneIndexedIndexValue.
+Value createLaneIndexedIndexValue(MaterializerState& state,
+                                  MaterializedClass& materializedClass,
+                                  ArrayRef values,
+                                  Location loc) {
+  assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class");
+  assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane");
+
+  SmallVector widened;
+  widened.reserve(values.size());
+  for (int32_t value : values)
+    widened.push_back(value);
+
+  return createLaneIndexedIndexValue(state, materializedClass, ArrayRef(widened), loc);
+}
+// Collects, per cpu of the class, the entry at `logicalSlot`; fails if a cpu has no entry list or the slot is out of range.
+FailureOr>
+getPeerLogicalInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId logicalSlot) {
+  SmallVector peers;
+  peers.reserve(materializedClass.cpus.size());
+  for (CpuId cpu : materializedClass.cpus) {
+    auto streamIt = state.logicalInstancesByCpu.find(cpu);
+    if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size())
+      return failure();
+    peers.push_back(streamIt->second[logicalSlot]);
+  }
+  return peers;
+}
+/* Yields the peers' `laneStart`: a constant taken from the first peer for a non-batch class, otherwise a value
+ * indexed by the batch lane argument over every peer's `laneStart`. */
+Value createOriginalLaneValue(MaterializerState& state,
+                              MaterializedClass& materializedClass,
+                              ArrayRef peers,
+                              Location loc) {
+  assert(!peers.empty() && "expected at least one peer instance");
+  if (!materializedClass.isBatch)
+    return getOrCreateIndexConstant(state.constantFolder, materializedClass.op, peers.front().laneStart);
+
+  auto batch = cast(materializedClass.op);
+  auto laneArg = batch.getLaneArgument();
+  assert(laneArg && "expected materialized compute_batch lane argument");
+
+  SmallVector laneValues;
+
laneValues.reserve(peers.size());
+  for (const ComputeInstance& peer : peers)
+    laneValues.push_back(peer.laneStart);
+
+  return createIndexedIndexValue(state, materializedClass.op, ArrayRef(laneValues), *laneArg, loc);
+}
+/* Worklist walk over the uses of `value` that lie outside old compute ops. Owners matched by the isa check below
+ * have their results followed instead of counting as a use; any other owner makes this return true. */
+bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) {
+  SmallVector worklist {value};
+  DenseSet visited;
+
+  while (!worklist.empty()) {
+    Value current = worklist.pop_back_val();
+    if (!visited.insert(current).second) // already processed
+      continue;
+
+    for (OpOperand& use : current.getUses()) {
+      Operation* owner = use.getOwner();
+      if (isInsideOldCompute(owner, oldComputeOps))
+        continue;
+      if (isa(owner)) {
+        for (Value result : owner->getResults())
+          worklist.push_back(result);
+        continue;
+      }
+      return true;
+    }
+  }
+
+  return false;
+}
+// Same walk as hasLiveExternalUse, except that owners matched by the second isa check are skipped, not counted.
+bool hasRealComputeConsumer(Value value, const DenseSet& oldComputeOps) {
+  SmallVector worklist {value};
+  DenseSet visited;
+
+  while (!worklist.empty()) {
+    Value current = worklist.pop_back_val();
+    if (!visited.insert(current).second) // already processed
+      continue;
+
+    for (OpOperand& use : current.getUses()) {
+      Operation* owner = use.getOwner();
+      if (isInsideOldCompute(owner, oldComputeOps))
+        continue;
+      if (isa(owner)) {
+        for (Value result : owner->getResults())
+          worklist.push_back(result);
+        continue;
+      }
+      if (isa(owner))
+        continue;
+      return true;
+    }
+  }
+
+  return false;
+}
+
+FailureOr
+getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex);
+// True when `output`'s defining op passes the cast below, has a live external use, and has no real compute consumer.
+bool isTerminalHostBatchOutput(Value output, const DenseSet& oldComputeOps) {
+  auto batch = dyn_cast_or_null(output.getDefiningOp());
+  if (!batch || batch.getNumResults() == 0)
+    return false;
+  if (!hasLiveExternalUse(output, oldComputeOps))
+    return false;
+  return !hasRealComputeConsumer(output, oldComputeOps);
+}
+
+// Adds `classId` to `state.producerDestClasses[key]` unless it is already listed.
+void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) {
+  SmallVector& destinations = state.producerDestClasses[key];
+  if (!llvm::is_contained(destinations, classId))
+
destinations.push_back(classId);
+}
+// Redirects every use of `oldValue` whose owner is not inside an old compute op to `replacement`.
+void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet& oldComputeOps) {
+  SmallVector uses;
+  for (OpOperand& use : oldValue.getUses())
+    uses.push_back(&use); // snapshot the uses first; they are rewritten with set() in the loop below
+
+  for (OpOperand* use : uses) {
+    Operation* owner = use->getOwner();
+    if (isInsideOldCompute(owner, oldComputeOps))
+      continue;
+    use->set(replacement);
+  }
+}
+/* For every consumer input, resolves its producer keys and the class of each scheduled producer: same-class
+ * producers are indexed in `state.sameClassConsumerIndex`, cross-class ones get `targetClass` as a destination. */
+LogicalResult collectProducerDestinations(MaterializerState& state) {
+  return forEachLogicalConsumerInMaterializationOrder(
+      state,
+      [&](CpuId, ClassId targetClass, ComputeInstance scheduledConsumer, ComputeInstance logicalConsumer, SlotId)
+          -> LogicalResult {
+        SmallVector consumerInputs = getComputeInstanceInputs(scheduledConsumer);
+        for (auto [inputIndex, input] : llvm::enumerate(consumerInputs)) {
+          SmallVector producerKeys;
+          if (auto batchConsumer = dyn_cast(logicalConsumer.op))
+            producerKeys = collectProducerKeysForBatchInputDestinations(
+                state, batchConsumer, static_cast(inputIndex), input, logicalConsumer);
+          else
+            producerKeys = collectProducerKeysForDestinations(input, logicalConsumer);
+
+          for (ProducerKey producerKey : producerKeys) {
+            ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producerKey.instance);
+            auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer);
+            if (producerCpuIt == state.schedule.computeToCpuMap.end())
+              return logicalConsumer.op->emitError(
+                  "schedule materialization found an input produced by an unscheduled compute");
+
+            ClassId sourceClass = state.cpuToClass.lookup(producerCpuIt->second);
+            if (sourceClass == targetClass) {
+              SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, targetClass};
+              SmallVector& bucket = state.sameClassConsumerIndex[lookupKey];
+              if (!llvm::is_contained(bucket, producerKey)) // keep each bucket free of duplicates
+                bucket.push_back(producerKey);
+              continue;
+            }
+
+            appendDestinationClass(state, producerKey, targetClass);
+          }
+        }
+
+        return success();
+      });
+}
+
+bool isStaticSliceInBounds(ArrayRef offsets, RankedTensorType sourceType, RankedTensorType fragmentType) { // true iff a fragment placed at `offsets` fits inside the source
+  if (offsets.size() != static_cast(sourceType.getRank())
+      || offsets.size() != static_cast(fragmentType.getRank()))
+    return false;
+
+  for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) {
+    int64_t offset = offsets[dim];
+    if (offset < 0)
+      return false;
+
+    int64_t sourceDimSize = sourceType.getDimSize(dim);
+    int64_t fragmentDimSize = fragmentType.getDimSize(dim);
+    if (fragmentDimSize < 0 || sourceDimSize < 0 || fragmentDimSize > sourceDimSize) // negative sizes are rejected
+      return false;
+    if (offset > sourceDimSize - fragmentDimSize) // written as a subtraction, so offset + size is never formed
+      return false;
+  }
+
+  return true;
+}
+
+// True when ranks agree and, per dimension, [innerOffset, innerOffset + innerSize) lies within the outer range.
+bool isStaticSliceContainedIn(ArrayRef innerOffsets,
+                              ArrayRef innerSizes,
+                              ArrayRef outerOffsets,
+                              ArrayRef outerSizes) {
+  if (innerOffsets.size() != innerSizes.size() || outerOffsets.size() != outerSizes.size()
+      || innerOffsets.size() != outerOffsets.size())
+    return false;
+
+  for (size_t dim = 0; dim < innerOffsets.size(); ++dim) {
+    if (innerSizes[dim] < 0 || outerSizes[dim] < 0)
+      return false;
+
+    int64_t innerBegin = innerOffsets[dim];
+    int64_t innerEnd = innerBegin + innerSizes[dim];
+    int64_t outerBegin = outerOffsets[dim];
+    int64_t outerEnd = outerBegin + outerSizes[dim];
+    if (innerBegin < outerBegin || innerEnd > outerEnd)
+      return false;
+  }
+
+  return true;
+}
+// True when every stride equals 1.
+bool areAllUnitStrides(ArrayRef strides) {
+  return llvm::all_of(strides, [](int64_t stride) { return stride == 1; });
+}
+// Trip count ceil((ub - lb) / step) for constant bounds with step > 0 and ub >= lb; std::nullopt otherwise.
+static std::optional getStaticForTripCount(scf::ForOp loop) {
+  std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound());
+  std::optional upperBound = matchConstantIndexValue(loop.getUpperBound());
+  std::optional step = matchConstantIndexValue(loop.getStep());
+  if (!lowerBound || !upperBound || !step || *step <= 0 || *upperBound < *lowerBound)
+    return std::nullopt;
+
+  int64_t distance = *upperBound - *lowerBound;
+  return (distance + *step - 1) / *step; // round up
+}
+// Describes the loops enclosing `op`, outermost first; returns an empty list if any of them is not fully static.
+static SmallVector
collectEnclosingStaticProjectedLoops(Operation* op) {
+  SmallVector loops;
+  SmallVector reversedLoops;
+  for (Operation* current = op->getParentOp(); current; current = current->getParentOp()) // innermost first
+    if (auto loop = dyn_cast(current))
+      reversedLoops.push_back(loop);
+
+  for (scf::ForOp loop : llvm::reverse(reversedLoops)) { // emit outermost first
+    std::optional lowerBound = matchConstantIndexValue(loop.getLowerBound());
+    std::optional step = matchConstantIndexValue(loop.getStep());
+    std::optional tripCount = getStaticForTripCount(loop);
+    if (!lowerBound || !step || !tripCount)
+      return {}; // one non-static loop discards the whole list
+    loops.push_back(StaticProjectedLoopInfo {.iv = cast(loop.getInductionVar()),
+                                             .lowerBound = *lowerBound,
+                                             .step = *step,
+                                             .tripCount = *tripCount});
+  }
+  return loops;
+}
+/* Accepts `value` if it is the lane argument, the induction variable of one of `loops`, a constant, or the result
+ * of a single-result map application (cast below) whose operands are all accepted. `usesDynamicBinding` is set
+ * when a lane argument or induction variable is involved and is never reset to false here. */
+static bool
+isProjectedOffsetValue(Value value, Value laneArg, ArrayRef loops, bool& usesDynamicBinding) {
+  if (value == laneArg) {
+    usesDynamicBinding = true;
+    return true;
+  }
+
+  for (const StaticProjectedLoopInfo& loop : loops) {
+    if (value == loop.iv) {
+      usesDynamicBinding = true;
+      return true;
+    }
+  }
+
+  if (matchPattern(value, m_Constant()))
+    return true;
+
+  auto affineApply = value.getDefiningOp();
+  if (!affineApply || affineApply.getAffineMap().getNumResults() != 1)
+    return false;
+
+  bool nestedUsesDynamicBinding = false;
+  for (Value operand : affineApply.getMapOperands()) {
+    bool operandUsesDynamicBinding = false;
+    if (!isProjectedOffsetValue(operand, laneArg, loops, operandUsesDynamicBinding))
+      return false; // the caller's flag is left untouched when an operand is rejected
+    nestedUsesDynamicBinding = nestedUsesDynamicBinding || operandUsesDynamicBinding;
+  }
+
+  usesDynamicBinding = usesDynamicBinding || nestedUsesDynamicBinding;
+  return true;
+}
+
+static std::optional getConstantIndex(OpFoldResult value);
+// Product of all loop trip counts (1 without loops). NOTE(review): the unsigned product is not overflow-checked.
+static unsigned getProjectedFragmentsPerLogicalSlot(ArrayRef loopTripCounts) {
+  unsigned fragmentsPerLogicalSlot = 1;
+  for (int64_t tripCount : loopTripCounts) {
+    assert(tripCount > 0 && "projected loop trip counts must be positive");
+
fragmentsPerLogicalSlot *= static_cast(tripCount);
+  }
+  return fragmentsPerLogicalSlot;
+}
+// Emits an error on `anchor` unless the layout has a fragment type, a non-empty shape of that rank, and consistent counts.
+LogicalResult verifyProjectedFragmentLayout(Operation* anchor, const ProjectedFragmentLayout& layout) {
+  if (!layout.fragmentType || layout.fragmentShape.empty())
+    return anchor->emitError("projected fragment layout is missing fragment type metadata");
+  if (layout.fragmentShape.size() != static_cast(layout.fragmentType.getRank()))
+    return anchor->emitError("projected fragment layout rank does not match fragment type");
+  if (layout.payloadFragmentCount == 0 || layout.fragmentsPerLogicalSlot == 0) // also guards the modulo below
+    return anchor->emitError("projected fragment layout has an invalid fragment count");
+  if (layout.payloadFragmentCount % layout.fragmentsPerLogicalSlot != 0)
+    return anchor->emitError("projected fragment layout payload fragment count is incompatible with logical slots");
+  return success();
+}
+// Packed tensor type for `payloadFragmentCount` fragments, returned only if verifyPackableFragmentType succeeds.
+FailureOr
+getProjectedPayloadType(Operation* anchor, RankedTensorType fragmentType, unsigned payloadFragmentCount) {
+  if (failed(
+          verifyPackableFragmentType(anchor, fragmentType, payloadFragmentCount, "cannot create projected payload type")))
+    return failure();
+  return getPackedBatchTensorType(fragmentType, payloadFragmentCount);
+}
+// Transposes per-fragment offset lists into one list per dimension, i.e. result[dim][fragment].
+SmallVector, 4>
+buildProjectedFragmentOffsetsByDim(ArrayRef> fragmentOffsets, size_t rank) {
+  SmallVector, 4> fragmentOffsetsByDim(rank);
+  for (ArrayRef offsets : fragmentOffsets) {
+    assert(offsets.size() == rank && "projected offset rank mismatch");
+    for (size_t dim = 0; dim < rank; ++dim)
+      fragmentOffsetsByDim[dim].push_back(offsets[dim]);
+  }
+  return fragmentOffsetsByDim;
+}
+// Validates the layout, the payload type, and that the fragment offsets agree with their per-dimension copy and the rank.
+LogicalResult verifyProjectedTransferDescriptor(Operation* anchor, const ProjectedTransferDescriptor& descriptor) {
+  if (failed(verifyProjectedFragmentLayout(anchor, descriptor.layout)))
+    return failure();
+  if (!descriptor.payloadType)
+    return anchor->emitError("projected transfer descriptor is missing payload type");
+  if
(descriptor.fragmentOffsets.empty())
+    return anchor->emitError("projected transfer descriptor expected at least one fragment offset");
+  if (descriptor.fragmentOffsetsByDim.size() != descriptor.layout.fragmentShape.size())
+    return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
+  for (ArrayRef dimOffsets : descriptor.fragmentOffsetsByDim) // each per-dimension list needs one entry per fragment
+    if (dimOffsets.size() != descriptor.fragmentOffsets.size())
+      return anchor->emitError("projected transfer descriptor dimension-major offsets are inconsistent");
+  for (ArrayRef offsets : descriptor.fragmentOffsets) // each fragment needs one offset per dimension
+    if (offsets.size() != descriptor.layout.fragmentShape.size())
+      return anchor->emitError("projected transfer offset rank does not match fragment rank");
+  return success();
+}
+// Transfer-descriptor check plus: messages.size() * payloadFragmentCount must equal the number of fragment offsets.
+LogicalResult verifyProjectedSendDescriptor(Operation* anchor,
+                                            const ProjectedTransferDescriptor& descriptor,
+                                            const MessageVector& messages) {
+  if (failed(verifyProjectedTransferDescriptor(anchor, descriptor)))
+    return failure();
+  if (messages.size() * descriptor.layout.payloadFragmentCount != descriptor.fragmentOffsets.size())
+    return anchor->emitError("projected send descriptor metadata is inconsistent");
+  return success();
+}
+// Rebuilds the per-dimension offsets, derives (or cross-checks) the payload type, then verifies the descriptor.
+LogicalResult finalizeProjectedTransferDescriptor(Operation* anchor, ProjectedTransferDescriptor& descriptor) {
+  descriptor.fragmentOffsetsByDim =
+      buildProjectedFragmentOffsetsByDim(descriptor.fragmentOffsets, descriptor.layout.fragmentShape.size());
+
+  FailureOr payloadType =
+      getProjectedPayloadType(anchor, descriptor.layout.fragmentType, descriptor.layout.payloadFragmentCount);
+  if (failed(payloadType))
+    return failure();
+  if (descriptor.payloadType && descriptor.payloadType != *payloadType) // a preset type must match the derived one
+    return anchor->emitError("projected transfer descriptor payload type does not match projected layout");
+  descriptor.payloadType = *payloadType;
+
+  return verifyProjectedTransferDescriptor(anchor, descriptor);
+}
+// Evaluates an offset for a concrete `lane` and set of loop iteration indices; fails on anything it cannot resolve.
+static FailureOr
evaluateProjectedOffsetValue(OpFoldResult value,
+                             Value laneArg,
+                             uint32_t lane,
+                             ArrayRef loops,
+                             ArrayRef loopIterationIndices) {
+  if (std::optional constant = getConstantIndex(value))
+    return *constant;
+
+  Value current = dyn_cast(value);
+  if (!current)
+    return failure();
+  if (current == laneArg)
+    return static_cast(lane);
+
+  for (auto [index, loop] : llvm::enumerate(loops)) {
+    if (current != loop.iv)
+      continue;
+    if (index >= loopIterationIndices.size())
+      return failure();
+    return loop.lowerBound + loopIterationIndices[index] * loop.step; // induction value at that iteration
+  }
+
+  if (auto affineApply = current.getDefiningOp()) {
+    return evaluateAffineApply(affineApply, [&](Value operand) {
+      return evaluateProjectedOffsetValue(operand, laneArg, lane, loops, loopIterationIndices);
+    });
+  }
+
+  return failure();
+}
+/* Constant carried by an OpFoldResult: from an integer attribute, from the defining op matched by the cast below,
+ * or via m_ConstantInt. NOTE(review): only the m_ConstantInt path rejects negative values. */
+static std::optional getConstantIndex(OpFoldResult value) {
+  if (auto attr = dyn_cast(value)) {
+    auto intAttr = dyn_cast(attr);
+    if (!intAttr)
+      return std::nullopt;
+    return intAttr.getInt();
+  }
+
+  Value operand = dyn_cast(value);
+  if (!operand)
+    return std::nullopt;
+
+  if (auto constantIndex = operand.getDefiningOp())
+    return constantIndex.value();
+
+  APInt apInt;
+  if (matchPattern(operand, m_ConstantInt(&apInt))) {
+    if (apInt.isNegative())
+      return std::nullopt;
+    return static_cast(apInt.getSExtValue());
+  }
+
+  return std::nullopt;
+}
+/* Matches a batch input whose only user is a slice extraction taken directly from it, with unit strides, static
+ * sizes, and offsets built from the lane argument, enclosing loop IVs and constants. `fail` drops its reason. */
+static std::optional matchAffineProjectedInputSlice(SpatComputeBatch batch,
+                                                    unsigned inputIndex) {
+  const auto fail = [&](StringRef) -> std::optional { return std::nullopt; };
+
+  std::optional inputArg = batch.getInputArgument(inputIndex);
+  std::optional laneArg = batch.getLaneArgument();
+  if (!inputArg || !laneArg)
+    return fail("missing-input-or-lane-arg");
+
+  if (!inputArg->hasOneUse())
+    return fail("input-arg-not-one-use");
+
+  Operation* user = *inputArg->getUsers().begin();
+  auto extract = dyn_cast(user);
+  if (!extract || extract.getSource() != *inputArg)
+    return
fail("input-user-is-not-direct-extract-slice");
+
+  auto inputType = dyn_cast(inputArg->getType());
+  auto fragmentType = dyn_cast(extract.getResult().getType());
+  if (!inputType || !fragmentType || !inputType.hasStaticShape() || !fragmentType.hasStaticShape())
+    return fail("non-static-ranked-input-or-fragment");
+
+  if (inputType.getRank() == 0 || inputType.getRank() != fragmentType.getRank())
+    return fail("rank-mismatch-or-rank-zero");
+
+  SmallVector offsets = extract.getMixedOffsets();
+  SmallVector sizes = extract.getMixedSizes();
+  SmallVector strides = extract.getMixedStrides();
+
+  if (offsets.size() != static_cast(inputType.getRank())
+      || sizes.size() != static_cast(inputType.getRank())
+      || strides.size() != static_cast(inputType.getRank()))
+    return fail("slice-rank-mismatch");
+
+  SmallVector loops = collectEnclosingStaticProjectedLoops(extract.getOperation());
+  if (extract->getParentOfType() && loops.empty()) // enclosed by the queried parent op, but no loop info was collected
+    return fail("unsupported-enclosing-loop");
+
+  bool hasDynamicProjection = false;
+  for (auto [dim, offset] : llvm::enumerate(offsets)) {
+    bool usesDynamicBinding = false;
+    if (auto value = dyn_cast(offset)) {
+      if (!isProjectedOffsetValue(value, *laneArg, loops, usesDynamicBinding))
+        return std::nullopt;
+    }
+    else if (!isa(offset))
+      return std::nullopt;
+    if (std::optional stride = getConstantIndex(strides[dim]); !stride || *stride != 1) // unit strides only
+      return std::nullopt;
+    std::optional size = getConstantIndex(sizes[dim]);
+    if (!size || *size != fragmentType.getDimSize(dim)) // the slice size must equal the fragment dimension
+      return std::nullopt;
+    hasDynamicProjection = hasDynamicProjection || usesDynamicBinding;
+  }
+
+  if (!hasDynamicProjection) // every offset is constant, so nothing is projected
+    return fail("no-dynamic-projection");
+
+  for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
+    if (fragmentType.getDimSize(dim) <= 0 || fragmentType.getDimSize(dim) > inputType.getDimSize(dim))
+      return std::nullopt;
+
+  AffineProjectedInputSliceMatch match;
+  match.extract = extract;
+  match.sourceType = inputType;
+  match.fragmentType =
fragmentType;
+  match.offsets.assign(offsets.begin(), offsets.end());
+  match.fragmentShape.assign(fragmentType.getShape().begin(), fragmentType.getShape().end());
+  match.loops = std::move(loops);
+  return match;
+}
+// Memoized matchAffineProjectedInputSlice: hits are kept in `projectedInputMatches`, misses in `nonProjectedInputs`.
+std::optional
+getProjectedInputSliceMatch(MaterializerState& state, SpatComputeBatch batch, unsigned inputIndex) {
+  ProjectedBatchInputKey key {batch.getOperation(), inputIndex};
+  auto cached = state.projectedInputMatches.find(key);
+  if (cached != state.projectedInputMatches.end())
+    return cached->second;
+  if (state.nonProjectedInputs.contains(key)) // a previous attempt already failed for this input
+    return std::nullopt;
+
+  std::optional match = matchAffineProjectedInputSlice(batch, inputIndex);
+  if (!match) {
+    state.nonProjectedInputs.insert(key);
+    return std::nullopt;
+  }
+
+  state.projectedInputMatches.insert({key, *match});
+  return match;
+}
+
+FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane);
+// Evaluates `value` for a fixed `lane`: the lane argument, a constant index, or a single-result map folded over operands.
+FailureOr evaluateProjectionIndexLike(Value value, Value laneArg, uint32_t lane) {
+  if (value == laneArg)
+    return static_cast(lane);
+
+  if (std::optional constant = matchConstantIndexValue(value))
+    return *constant;
+
+  auto affineApply = value.getDefiningOp();
+  if (!affineApply || affineApply.getAffineMap().getNumResults() != 1)
+    return failure();
+
+  SmallVector operands;
+  operands.reserve(affineApply.getMapOperands().size());
+  for (Value operand : affineApply.getMapOperands()) {
+    FailureOr evaluated = evaluateProjectionIndexLike(operand, laneArg, lane);
+    if (failed(evaluated))
+      return failure();
+    operands.push_back(IntegerAttr::get(IndexType::get(value.getContext()), *evaluated)); // constant for folding
+  }
+
+  SmallVector results;
+  if (failed(affineApply.getAffineMap().constantFold(operands, results)) || results.size() != 1)
+    return failure();
+
+  auto intAttr = dyn_cast(results.front());
+  if (!intAttr)
+    return failure();
+  return intAttr.getInt();
+}
+// OpFoldResult overload: an integer attribute is returned directly, a Value is evaluated by the Value overload.
+FailureOr evaluateProjectionIndexLike(OpFoldResult value, Value laneArg, uint32_t lane) {
+  if (auto
attr = llvm::dyn_cast(value)) {
+    auto intAttr = dyn_cast(attr);
+    if (!intAttr)
+      return failure();
+    return intAttr.getInt();
+  }
+  return evaluateProjectionIndexLike(llvm::cast(value), laneArg, lane);
+}
+/* Scans the ops in the region of the batch body's terminator for the insert whose destination is the body block
+ * argument numbered `resultIndex` relative to the first output argument; fails when none is found. */
+FailureOr
+getBatchResultProjectionInsert(SpatComputeBatch batch, size_t resultIndex) {
+  auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator());
+  if (!inParallel)
+    return failure();
+
+  auto firstOutputArg = batch.getOutputArgument(0);
+  if (!firstOutputArg)
+    return failure();
+
+  for (Operation& op : inParallel.getRegion().front()) {
+    auto insert = dyn_cast(&op);
+    if (!insert)
+      continue;
+
+    auto outputArg = dyn_cast(insert.getDest());
+    if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) // destination must be a body block argument
+      continue;
+
+    unsigned candidateIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
+    if (candidateIndex == resultIndex)
+      return insert;
+  }
+
+  return failure();
+}
+// Evaluates each entry with evaluateProjectionIndexLike and fails as soon as one entry cannot be evaluated.
+FailureOr>
+evaluateStaticProjectionIndices(ArrayRef values, Value laneArg, uint32_t lane) {
+  SmallVector evaluated;
+  evaluated.reserve(values.size());
+  for (OpFoldResult value : values) {
+    FailureOr index = evaluateProjectionIndexLike(value, laneArg, lane);
+    if (failed(index))
+      return failure();
+    evaluated.push_back(*index);
+  }
+  return evaluated;
+}
+
+/* Checks that, for every iteration of the matched loops, the consumer slice lies inside the slice written by some
+ * producer lane. Returns true unchecked when the producer op fails the cast below or has no projection insert. */
+bool isProjectedInputSliceCompatibleWithProducerFragments(SpatComputeBatch consumerBatch,
+                                                          const AffineProjectedInputSliceMatch& match,
+                                                          ProducerKey producer,
+                                                          uint32_t consumerLane) {
+  auto producerBatch = dyn_cast_or_null(producer.instance.op);
+  if (!producerBatch)
+    return true;
+
+  FailureOr producerProjection =
+      getBatchResultProjectionInsert(producerBatch, producer.resultIndex);
+  if (failed(producerProjection))
+    return true;
+
+  std::optional producerLaneArg = producerBatch.getLaneArgument();
+  std::optional consumerLaneArg = consumerBatch.getLaneArgument();
+  if (!producerLaneArg || !consumerLaneArg)
+    return false;
+
+  SmallVector
consumerSizes(match.fragmentShape.begin(), match.fragmentShape.end());
+  SmallVector loopIterationIndices(match.loops.size(), 0);
+
+  const auto consumerSliceFitsOneProducerFragment = [&]() -> bool { // reads the current loopIterationIndices
+    SmallVector consumerOffsets;
+    consumerOffsets.reserve(match.offsets.size());
+    for (OpFoldResult offset : match.offsets) {
+      FailureOr evaluated =
+          evaluateProjectedOffsetValue(offset, *consumerLaneArg, consumerLane, match.loops, loopIterationIndices);
+      if (failed(evaluated))
+        return false;
+      consumerOffsets.push_back(*evaluated);
+    }
+
+    uint32_t producerLaneEnd = producer.instance.laneStart + producer.instance.laneCount;
+    for (uint32_t producerLane = producer.instance.laneStart; producerLane < producerLaneEnd; ++producerLane) {
+      FailureOr> producerOffsets =
+          evaluateStaticProjectionIndices(producerProjection->getMixedOffsets(), *producerLaneArg, producerLane);
+      FailureOr> producerSizes =
+          evaluateStaticProjectionIndices(producerProjection->getMixedSizes(), *producerLaneArg, producerLane);
+      FailureOr> producerStrides =
+          evaluateStaticProjectionIndices(producerProjection->getMixedStrides(), *producerLaneArg, producerLane);
+      if (failed(producerOffsets) || failed(producerSizes) || failed(producerStrides))
+        return false;
+      if (!areAllUnitStrides(*producerStrides))
+        return false;
+      if (isStaticSliceContainedIn(consumerOffsets, consumerSizes, *producerOffsets, *producerSizes))
+        return true; // one producer lane fully contains the consumer slice
+    }
+
+    return false;
+  };
+
+  if (match.loops.empty())
+    return consumerSliceFitsOneProducerFragment();
+
+  const auto recurse = [&](auto&& self, size_t loopIndex) -> bool { // enumerates every combination of loop iterations
+    if (loopIndex == match.loops.size())
+      return consumerSliceFitsOneProducerFragment();
+
+    for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) {
+      loopIterationIndices[loopIndex] = iteration;
+      if (!self(self, loopIndex + 1)) // every iteration must fit, so stop at the first miss
+        return false;
+    }
+    return true;
+  };
+
+  return recurse(recurse, 0);
+}
+
+
+LogicalResult collectProjectedTransfers(MaterializerState&
state) { + struct PendingProjectedTransferDescriptor { + ProjectedBatchInputKey inputKey; + Operation* extractOp = nullptr; + RankedTensorType sourceType; + RankedTensorType fragmentType; + SmallVector fragmentShape; + SmallVector, 16>, 8> fragmentOffsetsByLane; + SmallVector loopLowerBounds; + SmallVector loopSteps; + SmallVector loopTripCounts; + bool invalid = false; + }; + + DenseMap, ProducerKeyInfo> pending; + + const auto isIdentityProjectedTransfer = [&](const PendingProjectedTransferDescriptor& descriptor) { + if (!descriptor.sourceType || descriptor.sourceType != descriptor.fragmentType) + return false; + + if (descriptor.fragmentOffsetsByLane.size() != 1) + return false; + + ArrayRef> fragments = descriptor.fragmentOffsetsByLane.front(); + if (fragments.size() != 1) + return false; + + return llvm::all_of(fragments.front(), [](int64_t offset) { return offset == 0; }); + }; + + const auto appendEvaluatedFragments = [&](PendingProjectedTransferDescriptor& descriptor, + unsigned targetLane, + const AffineProjectedInputSliceMatch& match, + Value laneArg, + uint32_t lane) -> LogicalResult { + SmallVector loopIterationIndices; + loopIterationIndices.resize(match.loops.size(), 0); + + const auto appendOneFragment = [&]() -> LogicalResult { + SmallVector evaluatedOffsets; + evaluatedOffsets.reserve(match.offsets.size()); + for (OpFoldResult offset : match.offsets) { + FailureOr evaluated = + evaluateProjectedOffsetValue(offset, laneArg, lane, match.loops, loopIterationIndices); + if (failed(evaluated)) + return failure(); + evaluatedOffsets.push_back(*evaluated); + } + + if (!isStaticSliceInBounds(evaluatedOffsets, match.sourceType, match.fragmentType)) + return failure(); + + descriptor.fragmentOffsetsByLane[targetLane].push_back(std::move(evaluatedOffsets)); + return success(); + }; + + if (match.loops.empty()) + return appendOneFragment(); + + const auto recurse = [&](auto&& self, size_t loopIndex) -> LogicalResult { + if (loopIndex == match.loops.size()) + 
return appendOneFragment(); + + for (int64_t iteration = 0; iteration < match.loops[loopIndex].tripCount; ++iteration) { + loopIterationIndices[loopIndex] = iteration; + if (failed(self(self, loopIndex + 1))) + return failure(); + } + return success(); + }; + + return recurse(recurse, 0); + }; + + if (failed(forEachLogicalConsumerInMaterializationOrder( + state, + [&](CpuId cpu, + ClassId targetClassId, + ComputeInstance consumer, + ComputeInstance logicalConsumer, + SlotId logicalSlot) -> LogicalResult { + auto batch = dyn_cast(consumer.op); + if (!batch) + return success(); + + MaterializedClass& targetClass = state.classes[targetClassId]; + unsigned targetLane = 0; + if (targetClass.isBatch) { + auto targetLaneIt = targetClass.cpuToLane.find(cpu); + if (targetLaneIt == targetClass.cpuToLane.end()) + return consumer.op->emitError("projected transfer collection could not recover target lane"); + targetLane = targetLaneIt->second; + } + + for (auto [inputIndex, input] : llvm::enumerate(batch.getInputs())) { + SmallVector producers = collectProducerKeysForDestinations(input, logicalConsumer); + if (producers.size() != 1) + continue; + ProducerKey producer = producers.front(); + + ComputeInstance scheduledProducer = getScheduledChunkForLogicalInstance(state, producer.instance); + auto producerCpuIt = state.schedule.computeToCpuMap.find(scheduledProducer); + if (producerCpuIt == state.schedule.computeToCpuMap.end()) + continue; + + ClassId sourceClassId = state.cpuToClass.lookup(producerCpuIt->second); + if (sourceClassId == targetClassId) + continue; + + std::optional match = + getProjectedInputSliceMatch(state, batch, static_cast(inputIndex)); + if (!match) + continue; + if (!isProjectedInputSliceCompatibleWithProducerFragments( + batch, *match, producer, logicalConsumer.laneStart)) + continue; + + PendingProjectedTransferDescriptor& descriptor = pending[producer][targetClassId]; + if (descriptor.fragmentOffsetsByLane.empty()) { + descriptor.inputKey = 
{batch.getOperation(), static_cast(inputIndex)}; + descriptor.extractOp = match->extract.getOperation(); + descriptor.sourceType = match->sourceType; + descriptor.fragmentType = match->fragmentType; + descriptor.fragmentShape = match->fragmentShape; + descriptor.fragmentOffsetsByLane.resize(targetClass.isBatch ? targetClass.cpus.size() : 1); + descriptor.loopLowerBounds.reserve(match->loops.size()); + descriptor.loopSteps.reserve(match->loops.size()); + descriptor.loopTripCounts.reserve(match->loops.size()); + for (const StaticProjectedLoopInfo& loop : match->loops) { + descriptor.loopLowerBounds.push_back(loop.lowerBound); + descriptor.loopSteps.push_back(loop.step); + descriptor.loopTripCounts.push_back(loop.tripCount); + } + } + + ProjectedBatchInputKey currentInputKey {batch.getOperation(), static_cast(inputIndex)}; + if (!(descriptor.inputKey == currentInputKey) || descriptor.extractOp != match->extract.getOperation() + || descriptor.sourceType != match->sourceType || descriptor.fragmentType != match->fragmentType + || descriptor.fragmentShape != match->fragmentShape + || descriptor.loopLowerBounds.size() != match->loops.size()) { + descriptor.invalid = true; + continue; + } + for (auto [index, loop] : llvm::enumerate(match->loops)) { + if (descriptor.loopLowerBounds[index] != loop.lowerBound || descriptor.loopSteps[index] != loop.step + || descriptor.loopTripCounts[index] != loop.tripCount) { + descriptor.invalid = true; + break; + } + } + if (descriptor.invalid) + continue; + + if (targetLane >= descriptor.fragmentOffsetsByLane.size()) { + descriptor.invalid = true; + continue; + } + + if (failed(appendEvaluatedFragments( + descriptor, targetLane, *match, *batch.getLaneArgument(), logicalConsumer.laneStart))) { + descriptor.invalid = true; + continue; + } + + (void) logicalSlot; + } + + return success(); + }))) + return failure(); + + for (auto& producerEntry : pending) { + ProducerKey producer = producerEntry.first; + for (auto& classEntry : 
producerEntry.second) { + ClassId targetClassId = classEntry.first; + PendingProjectedTransferDescriptor& pendingDescriptor = classEntry.second; + + if (pendingDescriptor.invalid) + continue; + if (pendingDescriptor.fragmentOffsetsByLane.empty()) + continue; + if (isIdentityProjectedTransfer(pendingDescriptor)) + continue; + + MaterializedClass& targetClass = state.classes[targetClassId]; + ProjectedTransferDescriptor descriptor; + descriptor.inputKey = pendingDescriptor.inputKey; + descriptor.extractOp = pendingDescriptor.extractOp; + descriptor.layout.fragmentType = pendingDescriptor.fragmentType; + descriptor.layout.fragmentShape = pendingDescriptor.fragmentShape; + descriptor.layout.loopLowerBounds = pendingDescriptor.loopLowerBounds; + descriptor.layout.loopSteps = pendingDescriptor.loopSteps; + descriptor.layout.loopTripCounts = pendingDescriptor.loopTripCounts; + descriptor.layout.fragmentsPerLogicalSlot = getProjectedFragmentsPerLogicalSlot(descriptor.layout.loopTripCounts); + if (targetClass.isBatch) { + unsigned payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); + if (payloadFragmentCount == 0) + continue; + + // Batch-target projected replacements currently select fragments with the + // local materialization-run slot index. That is only unambiguous when each + // target lane receives one projected fragment. Multi-fragment payloads + // need an explicit producer-key to payload-slot mapping; otherwise two + // independently materialized runs can both select fragment 0 from the same + // packed receive and duplicate rows. 
+ if (payloadFragmentCount != 1) + continue; + + bool uniform = true; + for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) { + if (laneFragments.size() != payloadFragmentCount) { + uniform = false; + break; + } + } + if (!uniform) + continue; + + descriptor.layout.payloadFragmentCount = payloadFragmentCount; + descriptor.fragmentOffsets.reserve(pendingDescriptor.fragmentOffsetsByLane.size() * payloadFragmentCount); + for (ArrayRef> laneFragments : pendingDescriptor.fragmentOffsetsByLane) + llvm::append_range(descriptor.fragmentOffsets, laneFragments); + } + else { + if (pendingDescriptor.fragmentOffsetsByLane.size() != 1) + return targetClass.op->emitError("scalar projected transfer descriptor expected one local offset stream"); + if (pendingDescriptor.fragmentOffsetsByLane.front().empty()) + continue; + + descriptor.layout.payloadFragmentCount = pendingDescriptor.fragmentOffsetsByLane.front().size(); + llvm::append_range(descriptor.fragmentOffsets, pendingDescriptor.fragmentOffsetsByLane.front()); + if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) + return targetClass.op->emitError("scalar projected transfer offset count does not match the local run"); + } + if (failed(finalizeProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + + state.projectedTransfers[producer][targetClassId] = std::move(descriptor); + } + } + + return success(); +} + +static std::optional +collectScalarTargetProjectedDescriptor(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef keys, + bool requirePackedRunOffsetCountMatch) { + assert(!targetClass.isBatch && "scalar target projected descriptor helper expects a scalar class"); + + std::optional combined; + for (ProducerKey key : keys) { + auto producerIt = state.projectedTransfers.find(key); + if (producerIt == state.projectedTransfers.end()) + return std::nullopt; + + auto descriptorIt = producerIt->second.find(targetClass.id); + if 
(descriptorIt == producerIt->second.end()) + return std::nullopt; + + const ProjectedTransferDescriptor& descriptor = descriptorIt->second; + if (descriptor.fragmentOffsets.empty()) + return std::nullopt; + if (descriptor.layout.payloadFragmentCount == 0 || descriptor.layout.fragmentsPerLogicalSlot == 0) + return std::nullopt; + if (descriptor.fragmentOffsets.size() != descriptor.layout.payloadFragmentCount) + return std::nullopt; + if (descriptor.layout.payloadFragmentCount % descriptor.layout.fragmentsPerLogicalSlot != 0) + return std::nullopt; + + if (!combined) { + combined = descriptor; + continue; + } + + if (!(combined->inputKey == descriptor.inputKey) || combined->extractOp != descriptor.extractOp + || combined->layout.fragmentType != descriptor.layout.fragmentType + || combined->layout.fragmentShape != descriptor.layout.fragmentShape + || combined->layout.loopLowerBounds != descriptor.layout.loopLowerBounds + || combined->layout.loopSteps != descriptor.layout.loopSteps + || combined->layout.loopTripCounts != descriptor.layout.loopTripCounts + || combined->layout.fragmentsPerLogicalSlot != descriptor.layout.fragmentsPerLogicalSlot) + return std::nullopt; + + combined->layout.payloadFragmentCount += descriptor.layout.payloadFragmentCount; + llvm::append_range(combined->fragmentOffsets, descriptor.fragmentOffsets); + } + + if (!combined) + return std::nullopt; + + if (combined->fragmentOffsets.size() != combined->layout.payloadFragmentCount) + return std::nullopt; + + if (requirePackedRunOffsetCountMatch) { + if (combined->layout.payloadFragmentCount != keys.size() * combined->layout.fragmentsPerLogicalSlot) + return std::nullopt; + } + if (failed(finalizeProjectedTransferDescriptor(targetClass.op, *combined))) + return std::nullopt; + return combined; +} + +bool haveSameDestinationClasses(MaterializerState& state, ArrayRef keys) { + if (keys.empty()) + return true; + + auto firstIt = state.producerDestClasses.find(keys.front()); + ArrayRef first = firstIt == 
state.producerDestClasses.end() ? ArrayRef() : firstIt->second; + for (ProducerKey key : keys.drop_front()) { + auto it = state.producerDestClasses.find(key); + ArrayRef current = it == state.producerDestClasses.end() ? ArrayRef() : it->second; + if (first.size() != current.size()) + return false; + for (auto [lhs, rhs] : llvm::zip(first, current)) + if (lhs != rhs) + return false; + } + return true; +} + +ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey key) { + auto it = state.producerDestClasses.find(key); + if (it == state.producerDestClasses.end()) + return {}; + return it->second; +} + +// ----------------------------------------------------------------------------- +// Communication materialization helpers. +// ----------------------------------------------------------------------------- + +void appendScalarSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!sourceClass.isBatch && "scalar send helper expects a scalar source class"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, targetCoreId); + SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); +} + +LogicalResult appendScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "scalar send loop expects a scalar source class"); + assert(messages.size() > 1 && "send loop is only useful for multiple sends"); + assert(succeeded(messages.verify(sourceClass.op)) && "message 
metadata is inconsistent"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); + + auto sendLoop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + return success(); + }); + if (failed(sendLoop)) + return failure(); + return success(); +} + +FailureOr buildProjectedPackedPayload(MaterializerState& state, + MaterializedClass& targetClass, + Value fullPayload, + const ProjectedTransferDescriptor& descriptor, + Value messageIndex, + Location loc) { + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + if (descriptor.layout.payloadFragmentCount == 1) + return targetClass.op->emitError("projected packed payload builder expects a packed payload"); + + FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fullPayload, + targetClass.op, + "projected packed payload tried to reuse a tensor from another materialized class"); + if (failed(localizedPayload)) + return failure(); + FailureOr localizedMessageIndex = rematerializeIndexValueInClass(state, targetClass, messageIndex, loc); + if (failed(localizedMessageIndex)) + return failure(); + + Value init = tensor::EmptyOp::create( + 
state.rewriter, loc, descriptor.payloadType.getShape(), descriptor.payloadType.getElementType()) + .getResult(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value fragmentIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value payloadFragmentCount = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, descriptor.layout.payloadFragmentCount); + Value flatBase = arith::MulIOp::create(state.rewriter, loc, *localizedMessageIndex, payloadFragmentCount).getResult(); + Value flatIndex = arith::AddIOp::create(state.rewriter, loc, flatBase, fragmentIndex).getResult(); + + FailureOr> fragmentOffsets = + buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, flatIndex, loc); + if (failed(fragmentOffsets)) + return failure(); + FailureOr fragment = createStaticExtractSliceInClass( + state, targetClass, loc, *localizedPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + if (failed(fragment)) + return failure(); + + FailureOr packedOffset = + scaleIndexByDim0SizeInClass(state, targetClass, fragmentIndex, descriptor.layout.fragmentType.getDimSize(0), loc); + if (failed(packedOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, *fragment, acc, *packedOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +FailureOr buildProjectedPayloadForMessage(MaterializerState& state, + MaterializedClass& targetClass, + Value fullPayload, + const 
ProjectedTransferDescriptor& descriptor, + Value messageIndex, + Location loc) { + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + + FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fullPayload, + targetClass.op, + "projected payload tried to reuse a tensor from another materialized class"); + if (failed(localizedPayload)) + return failure(); + + if (descriptor.layout.payloadFragmentCount == 1) { + FailureOr> fragmentOffsets = + buildProjectedFragmentOffsetsInClass(state, targetClass, descriptor, messageIndex, loc); + if (failed(fragmentOffsets)) + return failure(); + return createStaticExtractSliceInClass( + state, targetClass, loc, *localizedPayload, *fragmentOffsets, descriptor.layout.fragmentShape); + } + + return buildProjectedPackedPayload(state, targetClass, *localizedPayload, descriptor, messageIndex, loc); +} + +LogicalResult appendProjectedScalarSendLoop(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ProjectedTransferDescriptor& descriptor, + const MessageVector& messages, + Location loc) { + assert(!sourceClass.isBatch && "projected scalar send expects scalar source class"); + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, descriptor, messages))) + return failure(); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + if (messages.size() == 1) { + Value channelId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.channelIds.front()); + Value sourceCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.sourceCoreIds.front()); + Value targetCoreId = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, messages.targetCoreIds.front()); + Value messageIndex = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + FailureOr 
sendPayload = + buildProjectedPayloadForMessage(state, sourceClass, payload, descriptor, messageIndex, loc); + if (failed(sendPayload)) + return failure(); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); + return success(); + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, sourceClass.op, static_cast(messages.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, sourceClass.op, 1); + + auto projectedSendLoop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value index, ValueRange, SmallVectorImpl&) { + Value channelId = createIndexedChannelId(state, sourceClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, messages, index, loc); + FailureOr sendPayload = + buildProjectedPayloadForMessage(state, sourceClass, payload, descriptor, index, loc); + if (failed(sendPayload)) + return failure(); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, *sendPayload); + return success(); + }); + if (failed(projectedSendLoop)) + return failure(); + return success(); +} + +LogicalResult appendSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const MessageVector& messages, + Location loc) { + assert(succeeded(messages.verify(sourceClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one send"); + + if (sourceClass.isBatch) { + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, 
messages.sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.targetCoreIds, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + return success(); + } + + if (messages.size() == 1) { + appendScalarSend(state, + sourceClass, + payload, + messages.channelIds.front(), + messages.sourceCoreIds.front(), + messages.targetCoreIds.front(), + loc); + return success(); + } + + return appendScalarSendLoop(state, sourceClass, payload, messages, loc); +} + +Value appendScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + Type type, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value channelIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, channelId); + Value sourceCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, sourceCoreId); + Value targetCoreIdValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, targetCoreId); + return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) + .getOutput(); +} + +Value appendReceive( + MaterializerState& state, MaterializedClass& targetClass, Type type, const MessageVector& messages, Location loc) { + assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one receive"); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + if (targetClass.isBatch) { + Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); + Value targetCoreId = 
createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); + return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput(); + } + + assert(messages.size() == 1 && "scalar target class can only receive one message at a time"); + return appendScalarReceive(state, + targetClass, + type, + messages.channelIds.front(), + messages.sourceCoreIds.front(), + messages.targetCoreIds.front(), + loc); +} + +LogicalResult registerLazyPackedScalarReceives(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Type fragmentType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds) { + if (!sourceClass.isBatch) + return sourceClass.op->emitError("lazy packed scalar receives expect a batch source class"); + + if (targetClass.isBatch) + return targetClass.op->emitError("lazy packed scalar receives expect a scalar target class"); + + if (keys.empty()) + return sourceClass.op->emitError("lazy packed scalar receive expects at least one producer key"); + + if (keys.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError("lazy packed scalar receive expects one producer key per source lane"); + + MessageVector messages; + messages.append(channelIds, sourceCoreIds, targetCoreIds); + if (failed(messages.verify(targetClass.op))) + return failure(); + + if (keys.size() != messages.size()) + return targetClass.op->emitError("lazy packed scalar receive metadata is inconsistent"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("lazy packed scalar receive expects a static ranked fragment type"); + + if (failed(verifyPackableFragmentType( + targetClass.op, fragmentType, keys.size(), "cannot create lazy packed scalar receive type"))) + return failure(); + + Operation* sourceOp = 
keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return sourceClass.op->emitError("lazy packed scalar receive expects one producer result"); + + if (key.instance.laneCount != 1) + return sourceClass.op->emitError("lazy packed scalar receive expects one lane per producer key"); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = targetClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredReceive; + packedRun.fragmentType = rankedFragmentType; + + packedRun.messages = std::move(messages); + + PackedScalarRunSlot slot; + llvm::append_range(slot.keys, keys); + packedRun.slots.push_back(std::move(slot)); + + if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) + return failure(); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +struct ScalarSourceReceivePlan { + ClassId targetClass = 0; + MessageVector messages; + Type receiveType; + Operation* projectedExtractOp = nullptr; + ProjectedFragmentLayout projectedLayout; +}; + +struct ProjectedScalarSendGroup { + MessageVector messages; + ProjectedTransferDescriptor descriptor; +}; + +struct ScalarSourceFanoutPlan { + SmallVector receivePlans; + std::optional ordinaryMessages; + SmallVector projectedSendGroups; +}; + +bool hasSameProjectedSendCompatibility(const ProjectedTransferDescriptor& lhs, const ProjectedTransferDescriptor& rhs) { + return lhs.layout.fragmentType == rhs.layout.fragmentType && lhs.layout.fragmentShape == rhs.layout.fragmentShape + && lhs.layout.fragmentsPerLogicalSlot == rhs.layout.fragmentsPerLogicalSlot + && lhs.layout.payloadFragmentCount == rhs.layout.payloadFragmentCount + && lhs.layout.loopLowerBounds == rhs.layout.loopLowerBounds && lhs.layout.loopSteps == rhs.layout.loopSteps + && lhs.layout.loopTripCounts == 
rhs.layout.loopTripCounts && lhs.payloadType == rhs.payloadType; +} + +SmallVector collectDestinationClassesForKeys(MaterializerState& state, ArrayRef keys) { + SmallVector destinations; + + for (ProducerKey key : keys) + for (ClassId destinationClass : getDestinationClasses(state, key)) + destinations.push_back(destinationClass); + + llvm::sort(destinations); + destinations.erase(std::unique(destinations.begin(), destinations.end()), destinations.end()); + return destinations; +} + +FailureOr buildScalarSourceFanoutPlan(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + ArrayRef destinationClasses, + Value payload) { + assert(!sourceClass.isBatch && "scalar-source send planning expects a scalar source class"); + + auto sourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "scalar source core id"); + if (failed(sourceCpu)) + return failure(); + + ScalarSourceFanoutPlan fanoutPlan; + fanoutPlan.receivePlans.reserve(destinationClasses.size()); + + const auto getProjectedDescriptor = + [&](ClassId destinationClass) -> FailureOr> { + MaterializedClass& targetClass = state.classes[destinationClass]; + if (!targetClass.isBatch) { + bool hasAnyProjectedDescriptor = llvm::any_of(keys, [&](ProducerKey key) { + auto producerIt = state.projectedTransfers.find(key); + return producerIt != state.projectedTransfers.end() && producerIt->second.count(destinationClass) != 0; + }); + + std::optional descriptor = collectScalarTargetProjectedDescriptor( + state, targetClass, keys, /*requirePackedRunOffsetCountMatch=*/keys.size() > 1); + if (hasAnyProjectedDescriptor && !descriptor) + return targetClass.op->emitError("incomplete scalar projected transfer descriptor for local run"); + return descriptor; + } + + if (keys.size() != 1) + return std::optional {}; + + auto producerIt = state.projectedTransfers.find(keys.front()); + if (producerIt == state.projectedTransfers.end()) + return std::optional {}; + + auto descriptorIt = 
producerIt->second.find(destinationClass); + if (descriptorIt == producerIt->second.end()) + return std::optional {}; + + const ProjectedTransferDescriptor& descriptor = descriptorIt->second; + if (failed(verifyProjectedTransferDescriptor(targetClass.op, descriptor))) + return failure(); + if (descriptor.fragmentOffsets.size() + != targetClass.cpus.size() * static_cast(descriptor.layout.payloadFragmentCount)) + return targetClass.op->emitError("inconsistent batch projected transfer descriptor"); + + return std::optional {descriptor}; + }; + + for (ClassId destinationClass : destinationClasses) { + if (destinationClass == sourceClass.id) + continue; + + MaterializedClass& targetClass = state.classes[destinationClass]; + + ScalarSourceReceivePlan receivePlan; + receivePlan.targetClass = destinationClass; + receivePlan.receiveType = payload.getType(); + + auto appendMessage = [&](CpuId targetCpu) -> LogicalResult { + auto checkedTargetCpu = getCheckedCoreId(targetClass.op, targetCpu, "scalar target core id"); + if (failed(checkedTargetCpu)) + return failure(); + int64_t channelId = state.nextChannelId++; + + receivePlan.messages.append(channelId, *sourceCpu, *checkedTargetCpu); + return success(); + }; + + if (!targetClass.isBatch) { + if (failed(appendMessage(targetClass.cpus.front()))) + return failure(); + } + else { + for (CpuId targetCpu : targetClass.cpus) + if (failed(appendMessage(targetCpu))) + return failure(); + } + + FailureOr> descriptor = getProjectedDescriptor(destinationClass); + if (failed(descriptor)) + return failure(); + + if (*descriptor) { + const ProjectedTransferDescriptor& projectedDescriptor = **descriptor; + + if (!targetClass.isBatch && projectedDescriptor.payloadType == payload.getType()) + return targetClass.op->emitError("scalar projected receive unexpectedly uses the full producer tensor type"); + + receivePlan.receiveType = projectedDescriptor.payloadType; + receivePlan.projectedExtractOp = projectedDescriptor.extractOp; + 
receivePlan.projectedLayout = projectedDescriptor.layout; + + auto groupIt = llvm::find_if(fanoutPlan.projectedSendGroups, [&](const ProjectedScalarSendGroup& group) { + return hasSameProjectedSendCompatibility(group.descriptor, projectedDescriptor); + }); + if (groupIt == fanoutPlan.projectedSendGroups.end()) { + ProjectedScalarSendGroup group; + group.descriptor.layout = projectedDescriptor.layout; + group.descriptor.payloadType = projectedDescriptor.payloadType; + fanoutPlan.projectedSendGroups.push_back(std::move(group)); + groupIt = std::prev(fanoutPlan.projectedSendGroups.end()); + } + + groupIt->messages.append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + llvm::append_range(groupIt->descriptor.fragmentOffsets, projectedDescriptor.fragmentOffsets); + } + else { + if (!fanoutPlan.ordinaryMessages) + fanoutPlan.ordinaryMessages = MessageVector {}; + fanoutPlan.ordinaryMessages->append( + receivePlan.messages.channelIds, receivePlan.messages.sourceCoreIds, receivePlan.messages.targetCoreIds); + } + + fanoutPlan.receivePlans.push_back(std::move(receivePlan)); + } + + for (ProjectedScalarSendGroup& group : fanoutPlan.projectedSendGroups) { + if (failed(finalizeProjectedTransferDescriptor(sourceClass.op, group.descriptor))) + return failure(); + if (failed(verifyProjectedSendDescriptor(sourceClass.op, group.descriptor, group.messages))) + return failure(); + } + + return fanoutPlan; +} + +LogicalResult emitScalarSourceFanoutSends(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const ScalarSourceFanoutPlan& plan, + Location loc) { + if (plan.ordinaryMessages && failed(appendSend(state, sourceClass, payload, *plan.ordinaryMessages, loc))) + return failure(); + + for (const ProjectedScalarSendGroup& group : plan.projectedSendGroups) + if (failed(appendProjectedScalarSendLoop(state, sourceClass, payload, group.descriptor, group.messages, loc))) + return failure(); + + 
return success(); +} + +LogicalResult emitScalarSourceCommunication( + MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Location loc) { + assert(!sourceClass.isBatch && "scalar-source communication expects a scalar source class"); + + for (ProducerKey key : keys) + state.availableValues.record(key, sourceClass.id, payload); + + SmallVector destinationClasses = collectDestinationClassesForKeys(state, keys); + auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, payload); + if (failed(fanoutPlan)) + return failure(); + if (failed(emitScalarSourceFanoutSends(state, sourceClass, payload, *fanoutPlan, loc))) + return failure(); + + for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { + MaterializedClass& targetClass = state.classes[plan.targetClass]; + + Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); + + if (plan.projectedExtractOp) { + state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = + ProjectedExtractReplacement {received, plan.projectedLayout}; + continue; + } + + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + } + + return success(); +} + +LogicalResult emitClassToClassCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Value payload, + Location loc) { + if (sourceClass.id == targetClass.id) { + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, payload); + return success(); + } + + if (!sourceClass.isBatch) + return sourceClass.op->emitError("scalar-source communication must be emitted through the scalar fanout planner"); + + if (!targetClass.isBatch) { + if (keys.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError( + "cannot materialize batch-to-scalar communication without one producer key per source lane") + << " keyCount=" << 
keys.size() << " laneCount=" << sourceClass.cpus.size(); + + Operation* sourceOp = keys.front().instance.op; + size_t sourceResultIndex = keys.front().resultIndex; + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != sourceResultIndex || key.instance.laneCount != 1) + return sourceClass.op->emitError( + "cannot materialize batch-to-scalar communication for incompatible source keys"); + } + + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(sourceClass.cpus.size()); + + auto targetCpu = getCheckedCoreId(targetClass.op, targetClass.cpus.front(), "batch-to-scalar target core id"); + if (failed(targetCpu)) + return failure(); + for (CpuId sourceCpu : sourceClass.cpus) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch-to-scalar source core id"); + if (failed(checkedSourceCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *targetCpu); + } + + if (failed(appendSend(state, sourceClass, payload, messages, loc))) + return failure(); + return registerLazyPackedScalarReceives(state, + sourceClass, + targetClass, + keys, + payload.getType(), + messages.channelIds, + messages.sourceCoreIds, + messages.targetCoreIds); + } + + if (sourceClass.cpus.size() != targetClass.cpus.size()) + return sourceClass.op->emitError( + "cannot materialize batch communication between equivalence classes of different sizes"); + + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(targetClass.cpus.size()); + + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = 
getCheckedCoreId(targetClass.op, targetClass.cpus[lane], "batch target core id"); + if (failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + } +/* One send in the source class and one receive in the target class share the same message list. */ + if (failed(appendSend(state, sourceClass, payload, messages, loc))) + return failure(); + Value received = appendReceive(state, targetClass, payload.getType(), messages, loc); +/* From here on every key resolves to the received value inside the target class. */ + for (ProducerKey key : keys) + state.availableValues.record(key, targetClass.id, received); + + return success(); +} +/* Forward declaration; emitHostCommunication calls this for batch sources. */ +FailureOr recordProjectedBatchHostFragmentsFromBatchValue(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& ownerClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc); +/* Binds payload to the host result slot that sourceClass keeps for originalOutput and records the matching op result in state.hostReplacements. Scalar classes get the payload as operand of their yield terminator; batch classes get a createDim0ParallelInsertSlice call at the start of their in_parallel terminator region. Emits an error when the slot, the terminator, the block arguments or the types do not fit. */ +LogicalResult +setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { + auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); + if (resultIt == sourceClass.hostOutputToResultIndex.end()) + return sourceClass.op->emitError("missing host result slot for materialized output") + << " ownerKind=" << (sourceClass.isBatch ? "batch" : "scalar") + << " hostOutputs=" << sourceClass.hostOutputs.size() + << " originalDef=" << (originalOutput.getDefiningOp() ?
originalOutput.getDefiningOp()->getName().getStringRef() + : StringRef("")); +/* NOTE(review): the host replacement is recorded before the checks below, so the entry stays in state.hostReplacements even when this function then returns an error - confirm callers abort the whole materialization on failure. */ + unsigned resultIndex = resultIt->second; + state.hostReplacements[originalOutput] = sourceClass.op->getResult(resultIndex); +/* Scalar owner: the payload becomes operand resultIndex of the yield terminator, and its type must equal the original output type. */ + if (!sourceClass.isBatch) { + auto yieldOp = dyn_cast(sourceClass.body->getTerminator()); + if (!yieldOp) + return sourceClass.op->emitError("expected spat.yield terminator in materialized compute"); + if (resultIndex >= yieldOp.getNumOperands()) + return sourceClass.op->emitError("host result index out of range for materialized compute"); + if (payload.getType() != originalOutput.getType()) + return sourceClass.op->emitError("cannot set scalar host output from fragment payload") + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); + + state.rewriter.modifyOpInPlace(yieldOp, [&] { yieldOp->setOperand(resultIndex, payload); }); + return success(); + } +/* Batch owner: the payload must be a static ranked tensor and is written through the output block argument for resultIndex, using the lane block argument. */ + auto batch = cast(sourceClass.op); + auto inParallelOp = dyn_cast(sourceClass.body->getTerminator()); + if (!inParallelOp) + return sourceClass.op->emitError("expected spat.in_parallel terminator in materialized compute_batch"); + + auto payloadType = dyn_cast(payload.getType()); + if (!payloadType || !payloadType.hasStaticShape()) + return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); + + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); + + auto outputArg = batch.getOutputArgument(resultIndex); + if (!outputArg) + return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); + + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); + return success(); +} +/* Publishes payload as the host-visible value of originalOutput. Does nothing when hasLiveExternalUseCached reports no live external use. When sourceClass owns the output the value is set directly. A batch source is handed to recordProjectedBatchHostFragmentsFromBatchValue and rejected if that records nothing. A scalar source forwards over one new channel to a scalar owner; batch owners and mismatched payload types are errors. keys defaults to empty and is only passed along on the batch path. */ +LogicalResult +emitHostCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + Value
payload, + Value originalOutput, + ArrayRef keys = {}) { + if (!hasLiveExternalUseCached(state, originalOutput)) + return success(); +/* A live external output must have a registered owner class. */ + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) + return sourceClass.op->emitError("missing host owner for live external output"); +/* When the producer is the owner, no message is needed. */ + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (sourceClass.id == ownerClass.id) + return setHostOutputValue(state, ownerClass, originalOutput, payload); + + // Keep the old deadlock-free communication discipline: only scalar-to-scalar + // host-owner forwarding is introduced here. Batch host publication remains on + // the owning batch path; projected terminal batch publication must use the + // explicit projected whole-batch path instead of generic host forwarding. + if (sourceClass.isBatch) { + FailureOr recordedProjectedHostFragments = recordProjectedBatchHostFragmentsFromBatchValue( + state, sourceClass, ownerClass, keys, payload, originalOutput, payload.getLoc()); + if (failed(recordedProjectedHostFragments)) + return failure(); + if (*recordedProjectedHostFragments) + return success(); + return sourceClass.op->emitError("batch host publication must be routed through the owning/projection-aware path"); + } + if (ownerClass.isBatch) + return ownerClass.op->emitError("generic host publication does not support batch host owners"); + if (payload.getType() != originalOutput.getType()) + return sourceClass.op->emitError("cannot forward fragment payload to scalar host owner") + << " payloadType=" << payload.getType() << " outputType=" << originalOutput.getType(); +/* Scalar-to-scalar forward: a single message from the first cpu of the source class to the first cpu of the owner class. */ + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceClass.cpus.front(), "host source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "host target core id"); + if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) + return failure(); +
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); +/* The owner sets its host output from the value it receives, not from the sender's payload. */ + if (failed(appendSend(state, sourceClass, payload, messages, payload.getLoc()))) + return failure(); + Value ownerPayload = appendReceive(state, ownerClass, payload.getType(), messages, payload.getLoc()); + return setHostOutputValue(state, ownerClass, originalOutput, ownerPayload); +} +/* Fans a produced payload out to the consumers of keys and to the host. Empty keys: nothing to do. Scalar source: delegates to emitScalarSourceCommunication, then calls emitHostCommunication without keys. Batch source: all keys must have the same destination classes; the payload is sent to each of them, published to the host with the keys, and recorded for the source class only after those steps succeed. */ +LogicalResult emitOutputFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (keys.empty()) + return success(); + + if (!sourceClass.isBatch) { + if (failed(emitScalarSourceCommunication(state, sourceClass, keys, payload, loc))) + return failure(); + + return emitHostCommunication(state, sourceClass, payload, originalOutput); + } +/* Destinations are read from keys.front() only, which is why all keys must agree on them. */ + if (!haveSameDestinationClasses(state, keys)) + return sourceClass.op->emitError( + "cannot materialize batched output whose lanes have different destination equivalence classes"); + + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) + if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) + return failure(); + + if (failed(emitHostCommunication(state, sourceClass, payload, originalOutput, keys))) + return failure(); + + for (ProducerKey key : keys) + state.availableValues.record(key, sourceClass.id, payload); + + return success(); +} +/* A fragment value paired with the producer key it was recorded under. */ +struct DirectWholeBatchFragment { + ProducerKey key; + Value fragment; +}; +/* Source kinds handled by emitWholeBatchFragmentGroup. */ +enum class WholeBatchFragmentSourceKind { + DeferredReceive, + DeferredLocalCompute, + PackedValue, + DirectValue +}; +/* Fragments that share a source kind and fragment type. collectWholeBatchFragmentGroups fills only the fields relevant to kind: messages and redundantReceives for DeferredReceive, sourceOp, resultIndex and sourceLanes for DeferredLocalCompute, packed, slotPackedType and slotIndices for PackedValue, directFragments for DirectValue. outputOffsets is filled for every kind except DirectValue. */ +struct WholeBatchFragmentGroup { + WholeBatchFragmentSourceKind kind = WholeBatchFragmentSourceKind::DirectValue; + RankedTensorType fragmentType; + SmallVector outputOffsets; + MessageVector messages; + Operation* sourceOp = nullptr; + size_t resultIndex = 0; + SmallVector sourceLanes; + Value packed; + RankedTensorType slotPackedType; + SmallVector slotIndices;
+ SmallVector, 16> directFragments; + SmallVector redundantReceives; +}; +/* Source kinds for projected whole-batch fragment groups. */ +enum class ProjectedWholeBatchFragmentSourceKind { + DeferredReceive, + PackedValue, + DirectValue +}; +/* A fragment value together with the offsets, sizes and strides of its insert. */ +struct ProjectedWholeBatchDirectFragment { + Value fragment; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; +}; +/* Projected fragments of one source kind. The ByDim vectors are indexed by dimension first and hold one entry per fragment, as filled by appendProjectedInsertCoordinates. */ +struct ProjectedWholeBatchFragmentGroup { + ProjectedWholeBatchFragmentSourceKind kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; + RankedTensorType fragmentType; + SmallVector, 4> offsetsByDim; + SmallVector, 4> sizesByDim; + SmallVector, 4> stridesByDim; + MessageVector messages; + SmallVector redundantOps; + Value packed; + RankedTensorType packedSourceType; + SmallVector packedIndices; + SmallVector directFragments; +}; +/* Plan for assembling a whole-batch tensor. coveredLanes holds one flag per lane, coveredLaneCount counts the flags that are set, and packedRuns and directFragments list the sources chosen so far. */ +struct WholeBatchAssemblyPlan { + RankedTensorType resultType; + int64_t rowsPerLane = 0; + uint32_t batchLaneCount = 0; + uint32_t coveredLaneCount = 0; + + SmallVector coveredLanes; + SmallVector packedRuns; + SmallVector directFragments; +}; +/* True when lane is inside plan.coveredLanes and its flag is non-zero. */ +bool wholeBatchLaneCovered(const WholeBatchAssemblyPlan& plan, uint32_t lane) { + return lane < plan.coveredLanes.size() && plan.coveredLanes[lane] != 0; +} +/* True when any already covered lane lies in the range of laneCount lanes starting at laneStart. The range is clipped to coveredLanes.size(), so lanes past the end never count as an overlap; an empty range never overlaps. NOTE(review): the end of the range is computed from two uint32_t values before clipping - confirm the sum cannot wrap. */ +bool wholeBatchRangeOverlaps(const WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { + if (laneCount == 0) + return false; + if (laneStart >= plan.coveredLanes.size()) + return false; + + uint32_t laneEnd = std::min(laneStart + laneCount, plan.coveredLanes.size()); + for (uint32_t lane = laneStart; lane < laneEnd; ++lane) + if (plan.coveredLanes[lane] != 0) + return true; + return false; +} +/* Marks laneCount lanes starting at laneStart as covered and increments coveredLaneCount once per lane that was not covered before. The range must be non-empty and in bounds; both conditions are checked by assert only. */ +void recordWholeBatchCoverage(WholeBatchAssemblyPlan& plan, uint32_t laneStart, uint32_t laneCount) { + assert(laneCount != 0 && "cannot cover an empty whole-batch range"); + assert(laneStart + laneCount <= plan.coveredLanes.size() && "whole-batch coverage out of bounds"); + + for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) { + if (plan.coveredLanes[lane] != 0) + continue; +
plan.coveredLanes[lane] = 1; + ++plan.coveredLaneCount; + } +} +/* Same overlap test as wholeBatchRangeOverlaps, but on a caller-owned flag array: true when any flagged lane lies in the given range after clipping the range to covered.size(). */ +bool localLaneRangeOverlaps(ArrayRef covered, uint32_t laneStart, uint32_t laneCount) { + if (laneCount == 0) + return false; + if (laneStart >= covered.size()) + return false; + + uint32_t laneEnd = std::min(laneStart + laneCount, covered.size()); + for (uint32_t lane = laneStart; lane < laneEnd; ++lane) + if (covered[lane] != 0) + return true; + return false; +} +/* Sets the flag of every lane in the given range. The range must fit inside covered; this is checked by assert only. */ +void markLocalLaneRangeCovered(MutableArrayRef covered, uint32_t laneStart, uint32_t laneCount) { + assert(laneStart + laneCount <= covered.size() && "local coverage out of bounds"); + for (uint32_t lane = laneStart; lane < laneStart + laneCount; ++lane) + covered[lane] = 1; +} +/* Succeeds when fragmentType has a static shape, the same rank as resultType, expectedRows in dimension 0 and the same size as resultType in every other dimension. Returns a plain failure without emitting a diagnostic. */ +LogicalResult +validateWholeBatchFragmentType(RankedTensorType resultType, RankedTensorType fragmentType, int64_t expectedRows) { + if (!fragmentType.hasStaticShape()) + return failure(); + if (fragmentType.getRank() != resultType.getRank()) + return failure(); + if (fragmentType.getDimSize(0) != expectedRows) + return failure(); + + for (int64_t dim = 1; dim < resultType.getRank(); ++dim) + if (fragmentType.getDimSize(dim) != resultType.getDimSize(dim)) + return failure(); + + return success(); +} + +// ----------------------------------------------------------------------------- +// Packed run tensor assembly helpers.
+// ----------------------------------------------------------------------------- +/* Thin wrapper that forwards to createDim0InsertSliceInClass with the arguments reordered. */ +FailureOr insertFragmentIntoWholeBatch(MaterializerState& state, + MaterializedClass& targetClass, + Value fragment, + Value destination, + OpFoldResult firstOffset, + Location loc) { + return createDim0InsertSliceInClass(state, targetClass, loc, fragment, destination, firstOffset); +} +/* Extracts one slot from packed. The dimension-0 offset is obtained from scaleIndexByDim0SizeInClass using slotIndex and the dimension-0 size of slotPackedType, and the extracted slice uses that same dimension-0 size. */ +FailureOr extractPackedSlotForIndex(MaterializerState& state, + MaterializedClass& targetClass, + Value packed, + RankedTensorType slotPackedType, + Value slotIndex, + Location loc) { + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, slotIndex, slotPackedType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + return createDim0ExtractSliceInClass(state, targetClass, loc, packed, *firstOffset, slotPackedType.getDimSize(0)); +} +/* Returns the keys of all slots of run concatenated in slot order. */ +SmallVector flattenPackedScalarRunKeys(const PackedScalarRunValue& run) { + SmallVector keys; + for (const PackedScalarRunSlot& slot : run.slots) + llvm::append_range(keys, slot.keys); + return keys; +} +/* True when both runs have the same number of slots and each pair of corresponding slots holds equal key sequences. */ +bool packedScalarRunSlotsMatch(const PackedScalarRunValue& lhs, const PackedScalarRunValue& rhs) { + if (lhs.slots.size() != rhs.slots.size()) + return false; + + for (auto [lhsSlot, rhsSlot] : llvm::zip(lhs.slots, rhs.slots)) { + if (lhsSlot.keys.size() != rhsSlot.keys.size()) + return false; + if (!llvm::equal(lhsSlot.keys, rhsSlot.keys)) + return false; + } + + return true; +} + +/* Returns the sign-extended constant when value matches m_ConstantInt, otherwise nullopt. */ +std::optional getConstantIndexValue(Value value) { + APInt constant; + if (matchPattern(value, m_ConstantInt(&constant))) + return constant.getSExtValue(); + return std::nullopt; +} +/* Appends the channel, source core and target core ids of receive to messages when all three are constants. Returns false and leaves messages untouched otherwise. */ +bool appendConstantChannelReceiveMessage(MessageVector& messages, SpatChannelReceiveOp receive) { + std::optional channelId = getConstantIndexValue(receive.getChannelId()); + std::optional sourceCoreId = getConstantIndexValue(receive.getSourceCoreId()); + std::optional targetCoreId = getConstantIndexValue(receive.getTargetCoreId()); + if (!channelId || !sourceCoreId
|| !targetCoreId) + return false; + messages.append(*channelId, static_cast(*sourceCoreId), static_cast(*targetCoreId)); + return true; +} +/* Searches the packed runs indexed for the source op and result index of run in targetClass for a different run of kind DeferredReceive with the same fragment type and matching slot keys. Returns a pointer to the first such run, or nullptr when there is none. */ +PackedScalarRunValue* findDeferredReceiveAlternativeForPackedRun(MaterializerState& state, + const MaterializedClass& targetClass, + const PackedScalarRunValue& run) { + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(run.sourceOp, run.resultIndex, targetClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); +/* Skip the run itself and every candidate that is not an equivalent deferred receive. */ + for (size_t runIndex : runIndices) { + PackedScalarRunValue& candidate = state.availableValues.getPackedRun(runIndex); + if (&candidate == &run || candidate.kind != PackedScalarRunKind::DeferredReceive) + continue; + if (candidate.fragmentType != run.fragmentType) + continue; + if (!packedScalarRunSlotsMatch(candidate, run)) + continue; + return &candidate; + } + + return nullptr; +} +/* Builds, through buildNormalizedScfFor, a loop from 0 to itemCount with step 1 that carries destination as its single iteration argument. Each iteration calls buildFragment and buildOffset with the loop index and inserts the fragment at that offset. Returns the first loop result, or failure if the loop or either callback fails. */ +FailureOr emitIndexedFragmentInsertLoop(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + int64_t itemCount, + IndexedFragmentBuilder buildFragment, + IndexedInsertOffsetBuilder buildOffset, + Location loc) { + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, itemCount); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + Operation* insertionPoint = targetClass.body->getTerminator(); +/* The loop is emitted immediately before the terminator of the target class body. */ + state.rewriter.setInsertionPoint(insertionPoint); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + FailureOr fragment = buildFragment(flatIndex); + if (failed(fragment)) + return failure(); + FailureOr offset = buildOffset(flatIndex); + if (failed(offset)) + return failure(); + FailureOr next = +
insertFragmentIntoWholeBatch(state, targetClass, *fragment, iterArgs.front(), *offset, loc); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} +/* Forward declaration; called by the deferred local compute paths below. */ +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value laneValue, + ArrayRef resultIndices, + CloneIndexingContext indexing = {}); + +Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc); +FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, + MaterializedClass& targetClass, + IndexedBatchRunValue& run, + Value runSlotIndex, + Location loc); +/* Materializes a deferred local packed run inside targetClass. Creates an empty packed tensor sized for all keys of the run, then loops once per key, cloning the producer body for that key's lane and inserting the single produced value into the accumulator. The loop result is cached in run.packed and returned. Returns a plain failure, without a diagnostic, when the run has no keys or a key spans more than one lane. */ +FailureOr materializeDeferredLocalPackedScalarRunValue(MaterializerState& state, + MaterializedClass& targetClass, + PackedScalarRunValue& run, + Location loc) { + assert(isDeferredLocalPackedScalarRun(run) && "expected deferred local packed scalar run"); + + SmallVector keys = flattenPackedScalarRunKeys(run); + if (keys.empty()) + return failure(); + FailureOr packedType = getPackedBatchTensorType(run.fragmentType, keys.size()); + if (failed(packedType)) + return targetClass.op->emitError("cannot materialize deferred local packed run for non-static ranked tensor"); +/* One source lane per key, in flattened key order; the loop index selects the lane at run time. */ + SmallVector sourceLanes; + sourceLanes.reserve(keys.size()); + for (ProducerKey key : keys) { + if (key.instance.laneCount != 1) + return failure(); + sourceLanes.push_back(key.instance.laneStart); + } + + SmallVector resultIndices {run.resultIndex}; +/* The packed tensor starts as a tensor.empty placed before the terminator and is filled by the loop that follows. */ + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op,
static_cast(keys.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {init}, + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value acc = iterArgs.front(); + Value sourceLane = createIndexedIndexValue(state, targetClass.op, sourceLanes, loopIndex, loc); +/* NOTE(review): keys.front().instance is passed for every iteration while only sourceLane varies; this assumes all keys of the run belong to the same producer op - confirm. */ + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + keys.front().instance, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = loopIndex}); + if (failed(produced) || produced->size() != 1) + return failure(); +/* The insert offset is derived from the loop index and the dimension-0 size of the fragment type via scaleIndexByDim0SizeInClass. */ + FailureOr firstOffset = + scaleIndexByDim0SizeInClass(state, targetClass, loopIndex, run.fragmentType.getDimSize(0), loc); + if (failed(firstOffset)) + return failure(); + FailureOr next = createDim0InsertSliceInClass(state, targetClass, loc, produced->front(), acc, *firstOffset); + if (failed(next)) + return failure(); + yielded.push_back(*next); + return success(); + }); + if (failed(loop)) + return failure(); + run.packed = loop->results.front(); + return run.packed; +} +/* Adds every packed run indexed for key in targetClass to plan.packedRuns and marks the lanes of its fragment keys as covered. Fails when a fragment key names a different op or result index than key, covers zero lanes, or overlaps lanes already covered by the plan or by the same run. Runs without keys are skipped. No diagnostic is emitted on failure. */ +LogicalResult collectPackedRunsForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + WholeBatchAssemblyPlan& plan) { + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); + ArrayRef runIndices = state.availableValues.getPackedRunIndicesForWholeBatch(lookupKey); + + for (size_t runIndex : runIndices) { + PackedScalarRunValue& run = state.availableValues.getPackedRun(runIndex); +/* runCoveredLanes tracks the lanes claimed by this run alone, so overlaps inside a single run are caught before the plan is modified. */ + SmallVector runKeys; + SmallVector runCoveredLanes(plan.batchLaneCount, 0); + + for (const PackedScalarRunSlot& slot : run.slots) { + for (ProducerKey fragmentKey : slot.keys) { + if (fragmentKey.instance.op != key.instance.op || fragmentKey.resultIndex != key.resultIndex) + return failure(); + + if
(fragmentKey.instance.laneCount == 0) + return failure(); + + if (wholeBatchRangeOverlaps(plan, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + return failure(); + + if (localLaneRangeOverlaps(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount)) + return failure(); +/* NOTE(review): both overlap checks clip the range to the number of lanes, and the mark and record helpers only assert their bounds, so a range reaching past plan.batchLaneCount is not rejected at runtime here - confirm it is validated upstream. */ + markLocalLaneRangeCovered(runCoveredLanes, fragmentKey.instance.laneStart, fragmentKey.instance.laneCount); + runKeys.push_back(fragmentKey); + } + } + + if (runKeys.empty()) + continue; +/* The plan is only updated once the whole run has passed every check. */ + plan.packedRuns.push_back(&run); + + for (ProducerKey runKey : runKeys) + recordWholeBatchCoverage(plan, runKey.instance.laneStart, runKey.instance.laneCount); + } + + return success(); +} +/* Fills the lanes the plan does not cover yet with exact fragments recorded for key in targetClass. A record becomes a candidate only if it comes from batch, matches key.resultIndex, covers at least one lane, is local to targetClass, does not overlap existing coverage and has a tensor type accepted by validateWholeBatchFragmentType. Lanes are then walked in ascending order and each uncovered lane takes a candidate that starts exactly there; the function fails, without a diagnostic, when no such candidate exists. */ +LogicalResult collectDirectFragmentsForWholeBatchInput(MaterializerState& state, + MaterializedClass& targetClass, + SpatComputeBatch batch, + ProducerKey key, + WholeBatchAssemblyPlan& plan) { + struct CandidateFragment { + ProducerKey key; + Value value; + }; +/* NOTE(review): the early exit compares against plan.batchLaneCount while the lane walk below uses batch.getLaneCount(); presumably the two are equal - confirm. */ + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + if (plan.coveredLaneCount == plan.batchLaneCount) { + return success(); + } +/* Unusable records are skipped rather than treated as errors. */ + WholeBatchAssemblyLookupKey lookupKey = makeWholeBatchAssemblyLookupKey(key, targetClass.id); + ArrayRef indexedFragments = + state.availableValues.getExactFragmentsForWholeBatch(lookupKey); + + SmallVector candidates; + candidates.reserve(indexedFragments.size()); + for (const AvailableValueStore::ExactBatchFragmentRecord& record : indexedFragments) { + ProducerKey candidateKey = record.key; + if (candidateKey.instance.op != batch.getOperation() || candidateKey.resultIndex != key.resultIndex + || candidateKey.instance.laneCount == 0) + continue; + if (!isTensorValueLocalToMaterializedClass(record.value, targetClass)) + continue; + if (wholeBatchRangeOverlaps(plan, candidateKey.instance.laneStart, candidateKey.instance.laneCount)) + continue; + + auto fragmentType = dyn_cast(record.value.getType()); + if (!fragmentType) + continue; + + int64_t expectedRows = plan.rowsPerLane *
static_cast(candidateKey.instance.laneCount); + if (failed(validateWholeBatchFragmentType(plan.resultType, fragmentType, expectedRows))) + continue; + + candidates.push_back({candidateKey, record.value}); + } +/* Sort by laneStart ascending and, for equal starts, by laneCount descending so the widest fragment is tried first. */ + llvm::sort(candidates, [](const CandidateFragment& lhs, const CandidateFragment& rhs) { + if (lhs.key.instance.laneStart != rhs.key.instance.laneStart) + return lhs.key.instance.laneStart < rhs.key.instance.laneStart; + return lhs.key.instance.laneCount > rhs.key.instance.laneCount; + }); +/* Walk the lanes in order: skip covered lanes, advance the cursor past candidates that start earlier, then take the first candidate starting at the current lane that does not overlap existing coverage. */ + size_t candidateCursor = 0; + uint32_t lane = 0; + while (lane < batchLaneCount) { + while (lane < batchLaneCount && wholeBatchLaneCovered(plan, lane)) { + ++lane; + } + + if (lane >= batchLaneCount) + break; + + while (candidateCursor < candidates.size() && candidates[candidateCursor].key.instance.laneStart < lane) + ++candidateCursor; + + size_t candidateIndex = candidateCursor; + const CandidateFragment* best = nullptr; + while (candidateIndex < candidates.size() && candidates[candidateIndex].key.instance.laneStart == lane) { + const CandidateFragment& candidate = candidates[candidateIndex]; + if (!wholeBatchRangeOverlaps(plan, lane, candidate.key.instance.laneCount)) { + best = &candidate; + break; + } + ++candidateIndex; + } +/* An uncovered lane without a usable candidate starting at it makes the whole collection fail. */ + if (!best) + return failure(); + + plan.directFragments.push_back({best->key, best->value}); + recordWholeBatchCoverage(plan, lane, best->key.instance.laneCount); + lane += best->key.instance.laneCount; + } + + return success(); +} +/* Converts the packed runs and direct fragments of plan into fragment groups appended to groups. Entries with the same source kind and fragment type are merged into one group; DeferredLocalCompute groups additionally match on source op and result index, PackedValue groups on the packed value and slot-packed type. Fails when a run or fragment is inconsistent with the plan or is not usable from targetClass. */ +LogicalResult collectWholeBatchFragmentGroups(MaterializerState& state, + MaterializedClass& targetClass, + const WholeBatchAssemblyPlan& plan, + SmallVectorImpl& groups) { + for (PackedScalarRunValue* run : plan.packedRuns) { + if (!run || run->slots.empty()) + continue; + if (run->fragmentType.getDimSize(0) != plan.rowsPerLane) + return failure(); +/* A materialized packed value that is not local to targetClass is swapped for an equivalent deferred-receive run when one exists; otherwise a diagnostic is emitted and the collection fails. */ + if (run->kind == PackedScalarRunKind::Materialized && run->packed + && !isTensorValueLocalToMaterializedClass(run->packed, targetClass)) { + if
(PackedScalarRunValue* deferredRun = findDeferredReceiveAlternativeForPackedRun(state, targetClass, *run)) + run = deferredRun; + else { + SmallVector keys = flattenPackedScalarRunKeys(*run); + std::optional packedKey = getContiguousProducerRangeForKeys(keys); + emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, + targetClass, + "whole-batch assembly tried to reuse non-local PackedValue", + run->packed, + packedKey); + return failure(); + } + } +/* Deferred receives with the same fragment type share one group: the messages of the run are appended and one dimension-0 output offset is added per fragment key. */ + if (run->kind == PackedScalarRunKind::DeferredReceive) { + if (failed(validatePackedScalarRunMetadata(targetClass.op, *run))) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == run->fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = run->fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } +/* NOTE(review): emitWholeBatchFragmentGroup indexes messages and outputOffsets with the same loop index, so the run is expected to carry exactly one message per fragment key - presumably guaranteed by validatePackedScalarRunMetadata; confirm. */ + groupIt->messages.append(run->messages.channelIds, run->messages.sourceCoreIds, run->messages.targetCoreIds); + for (const PackedScalarRunSlot& slot : run->slots) + for (ProducerKey fragmentKey : slot.keys) + groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + continue; + } +/* Deferred local compute runs are grouped by fragment type, source op and result index. */ + if (run->kind == PackedScalarRunKind::DeferredLocalCompute) { + SmallVector keys = flattenPackedScalarRunKeys(*run); + if (keys.empty()) + return failure(); + + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredLocalCompute + && group.fragmentType == run->fragmentType && group.sourceOp == run->sourceOp + && group.resultIndex == run->resultIndex; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredLocalCompute; + group.fragmentType =
run->fragmentType; + group.sourceOp = run->sourceOp; + group.resultIndex = run->resultIndex; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } +/* Each key must cover a single lane; its lane start is stored as the source lane and, multiplied by rowsPerLane, as the output offset. */ + for (ProducerKey fragmentKey : keys) { + if (fragmentKey.instance.laneCount != 1) + return failure(); + groupIt->sourceLanes.push_back(fragmentKey.instance.laneStart); + groupIt->outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + } + continue; + } +/* Every remaining run needs both a source op that passes the dyn_cast_or_null below and an already packed value. */ + auto sourceBatch = dyn_cast_or_null(run->sourceOp); + if (!sourceBatch || !run->packed) + return failure(); +/* Packed-value groups are keyed by fragment type, packed value and slot-packed type. The returned reference points into groups and is used before groups can grow again. */ + auto getOrCreatePackedValueGroup = [&](RankedTensorType slotPackedType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::PackedValue && group.fragmentType == run->fragmentType + && group.packed == run->packed && group.slotPackedType == slotPackedType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::PackedValue; + group.fragmentType = run->fragmentType; + group.packed = run->packed; + group.slotPackedType = slotPackedType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; +/* A slot whose keys form a contiguous producer range is addressed as one slice by its slot index; any other slot is addressed key by key through the flattened key index. NOTE(review): extractPackedSlotForIndex later scales the stored index by the dimension-0 size of the slot-packed type, so a slot index only matches the flattened layout if all preceding slots hold the same number of keys - confirm that invariant. */ + size_t flattenedIndexBase = 0; + for (auto [slotIndex, slot] : llvm::enumerate(run->slots)) { + std::optional contiguousKey = getContiguousProducerRangeForKeys(slot.keys); + if (contiguousKey) { + FailureOr slotPackedType = getPackedBatchTensorType(run->fragmentType, slot.keys.size()); + if (failed(slotPackedType)) + return failure(); + WholeBatchFragmentGroup& group = getOrCreatePackedValueGroup(*slotPackedType); + group.slotIndices.push_back(slotIndex); + group.outputOffsets.push_back(static_cast(contiguousKey->instance.laneStart) * plan.rowsPerLane); + flattenedIndexBase += slot.keys.size(); + continue; + } + + WholeBatchFragmentGroup& group =
getOrCreatePackedValueGroup(run->fragmentType); + for (auto [keyIndex, fragmentKey] : llvm::enumerate(slot.keys)) { + group.slotIndices.push_back(flattenedIndexBase + keyIndex); + group.outputOffsets.push_back(static_cast(fragmentKey.instance.laneStart) * plan.rowsPerLane); + } + flattenedIndexBase += slot.keys.size(); + } + } +/* Finds or appends the DeferredReceive group for a fragment type. The returned reference points into groups and stays valid only until groups grows again. */ + auto getOrCreateDeferredReceiveGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; +/* Same find-or-append lookup for DirectValue groups. */ + auto getOrCreateDirectValueGroup = [&](RankedTensorType fragmentType) -> WholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const WholeBatchFragmentGroup& group) { + return group.kind == WholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + WholeBatchFragmentGroup group; + group.kind = WholeBatchFragmentSourceKind::DirectValue; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; +/* Direct fragments must already be local to targetClass; a non-local one is reported and aborts the collection. */ + for (const DirectWholeBatchFragment& fragment : plan.directFragments) { + if (!isTensorValueLocalToMaterializedClass(fragment.fragment, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic(targetClass.op, + targetClass, + "whole-batch assembly tried to reuse non-local DirectValue", + fragment.fragment, + fragment.key); + return failure(); + } + + auto fragmentType = dyn_cast(fragment.fragment.getType()); + if (!fragmentType) + return failure(); + + int64_t outputOffset = static_cast(fragment.key.instance.laneStart) *
plan.rowsPerLane; +/* A fragment defined by a receive op that has no uses and whose ids are all constants is folded into the deferred-receive group and queued for erasure; every other fragment is inserted directly. NOTE(review): the deferred-receive group is looked up or created before the constant check, so it can be left empty when the ids are not constant - confirm an empty group is harmless when emitted. */ + if (auto receive = fragment.fragment.getDefiningOp()) { + if (fragment.fragment.use_empty()) { + WholeBatchFragmentGroup& group = getOrCreateDeferredReceiveGroup(fragmentType); + if (appendConstantChannelReceiveMessage(group.messages, receive)) { + group.outputOffsets.push_back(outputOffset); + group.redundantReceives.push_back(receive.getOperation()); + continue; + } + } + } + + WholeBatchFragmentGroup& group = getOrCreateDirectValueGroup(fragmentType); + group.directFragments.push_back({fragment.fragment, outputOffset}); + } + + return success(); +} +/* Emits the inserts for one fragment group into destination and returns the updated tensor value. DeferredReceive, DeferredLocalCompute and PackedValue groups go through emitIndexedFragmentInsertLoop with one iteration per recorded entry; DirectValue groups emit one insert per fragment without a loop. Returns failure when any builder step fails. */ +FailureOr emitWholeBatchFragmentGroup(MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const WholeBatchFragmentGroup& group, + Location loc) { + switch (group.kind) { + case WholeBatchFragmentSourceKind::DeferredReceive: { + FailureOr updated = emitIndexedFragmentInsertLoop( + state, + targetClass, + destination, + static_cast(group.outputOffsets.size()), + [&](Value flatIndex) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); + return SpatChannelReceiveOp::create( + state.rewriter, loc, group.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + if (failed(updated)) + return failure(); +/* The queued receives are superseded by the receive created inside the loop; erase those that still have no uses. NOTE(review): the ops are erased directly rather than through state.rewriter - confirm no rewriter listener depends on the notification. */ + for (Operation* receive : group.redundantReceives) + if (receive && receive->use_empty()) + receive->erase(); + + return *updated; + } + case WholeBatchFragmentSourceKind::DeferredLocalCompute: { + SmallVector resultIndices {group.resultIndex}; + return emitIndexedFragmentInsertLoop( + state, + targetClass, + destination, +
static_cast(group.outputOffsets.size()), + [&](Value flatIndex) -> FailureOr { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, group.sourceLanes, flatIndex, loc); + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, +/* NOTE(review): the instance is built from the source op with the constants 0 and 1 while the actual lane is supplied through sourceLane; presumably those constants are placeholders - confirm against cloneBatchBodyForLane. */ ComputeInstance {group.sourceOp, 0, 1}, + sourceLane, + resultIndices, + CloneIndexingContext {.runSlotIndex = flatIndex, .projectionSlotIndex = flatIndex}); + if (failed(produced) || produced->size() != 1) + return failure(); + return produced->front(); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); + } +/* PackedValue: each iteration extracts one slot from the packed tensor and inserts it at its output offset. NOTE(review): materializeTensorValueForMaterializedClassUse is called inside the per-iteration builder although group.packed does not depend on the loop index - confirm this placement is intended. */ case WholeBatchFragmentSourceKind::PackedValue: + return emitIndexedFragmentInsertLoop( + state, + targetClass, + destination, + static_cast(group.slotIndices.size()), + [&](Value flatIndex) -> FailureOr { + Value packedSlotIndex = createIndexedIndexValue(state, targetClass.op, group.slotIndices, flatIndex, loc); + FailureOr packed = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + group.packed, + targetClass.op, + "whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(packed)) + return failure(); + return extractPackedSlotForIndex(state, targetClass, *packed, group.slotPackedType, packedSlotIndex, loc); + }, + [&](Value flatIndex) -> FailureOr { + return createIndexedIndexValue(state, targetClass.op, group.outputOffsets, flatIndex, loc); + }, + loc); +/* DirectValue: no loop; destination is threaded through one insert per fragment, each placed before the terminator of the target class body. */ case WholeBatchFragmentSourceKind::DirectValue: + for (const auto& [fragment, offset] : group.directFragments) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment, + targetClass.op, + "whole-batch direct fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(localFragment)) + return failure(); + FailureOr updated =
createDim0InsertSliceInClass(state, + targetClass, + loc, + *localFragment, + destination, + getOrCreateIndexConstant(state.constantFolder, targetClass.op, offset)); + if (failed(updated)) + return failure(); + destination = *updated; + } + return destination; + } + + return failure(); +} + +FailureOr emitProjectedWholeBatchFragmentInsertLoop( + MaterializerState& state, + MaterializedClass& targetClass, + Value destination, + const ProjectedWholeBatchFragmentGroup& group, + llvm::function_ref(Value)> buildFragment, + Location loc) { + assert(group.fragmentType && "expected projected fragment type"); + assert(!group.offsetsByDim.empty() && "expected projected insert coordinates"); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, group.offsetsByDim.front().size()); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {destination}, + [&](OpBuilder&, Location, Value flatIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + FailureOr fragment = buildFragment(flatIndex); + if (failed(fragment)) + return failure(); + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + unsigned rank = group.offsetsByDim.size(); + offsets.reserve(rank); + sizes.reserve(rank); + strides.reserve(rank); + for (unsigned dim = 0; dim < rank; ++dim) { + offsets.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.offsetsByDim[dim], flatIndex, loc)); + sizes.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.sizesByDim[dim], flatIndex, loc)); + strides.push_back(createIndexedOrStaticIndex(state, targetClass.op, group.stridesByDim[dim], flatIndex, loc)); + } + + Value updated = + 
tensor::InsertSliceOp::create(state.rewriter, loc, *fragment, iterArgs.front(), offsets, sizes, strides) + .getResult(); + yielded.push_back(updated); + return success(); + }); + if (failed(loop)) + return failure(); + return loop->results.front(); +} + +std::optional getStaticProjectedPackedFragmentIndex(tensor::ExtractSliceOp extract) { + auto sourceType = dyn_cast(extract.getSource().getType()); + auto resultType = dyn_cast(extract.getResult().getType()); + if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape() + || sourceType.getRank() == 0 || sourceType.getRank() != resultType.getRank()) + return std::nullopt; + + std::optional firstOffset = getConstantIndex(extract.getMixedOffsets().front()); + if (!firstOffset) + return std::nullopt; + + for (int64_t dim = 0; dim < sourceType.getRank(); ++dim) { + std::optional offset = getConstantIndex(extract.getMixedOffsets()[dim]); + std::optional size = getConstantIndex(extract.getMixedSizes()[dim]); + std::optional stride = getConstantIndex(extract.getMixedStrides()[dim]); + if (!offset || !size || !stride || *stride != 1 || *size != resultType.getDimSize(dim)) + return std::nullopt; + if (dim != 0 && *offset != 0) + return std::nullopt; + } + + return *firstOffset; +} + +void appendProjectedInsertCoordinates(ProjectedWholeBatchFragmentGroup& group, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + if (group.offsetsByDim.empty()) { + size_t rank = offsets.size(); + group.offsetsByDim.resize(rank); + group.sizesByDim.resize(rank); + group.stridesByDim.resize(rank); + } + + for (size_t dim = 0; dim < offsets.size(); ++dim) { + group.offsetsByDim[dim].push_back(offsets[dim]); + group.sizesByDim[dim].push_back(sizes[dim]); + group.stridesByDim[dim].push_back(strides[dim]); + } +} + +FailureOr buildWholeBatchAssemblyPlan(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType) { + auto batch = dyn_cast_or_null(key.instance.op); + 
auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || !resultTensorType.hasStaticShape() || resultTensorType.getRank() == 0) + return failure(); + + uint32_t batchLaneCount = static_cast(batch.getLaneCount()); + if (batchLaneCount == 0 || resultTensorType.getDimSize(0) % static_cast(batchLaneCount) != 0) + return failure(); + + WholeBatchAssemblyPlan plan; + plan.resultType = resultTensorType; + plan.rowsPerLane = resultTensorType.getDimSize(0) / static_cast(batchLaneCount); + plan.batchLaneCount = batchLaneCount; + plan.coveredLanes.assign(batchLaneCount, 0); + + if (failed(collectPackedRunsForWholeBatchInput(state, targetClass, key, plan))) + return failure(); + + if (plan.coveredLaneCount == plan.batchLaneCount) + return plan; + + if (failed(collectDirectFragmentsForWholeBatchInput(state, targetClass, batch, key, plan))) + return failure(); + + return plan; +} + +FailureOr emitWholeBatchAssemblyPlan(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + WholeBatchAssemblyPlan& plan, + Location loc) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, plan.resultType.getShape(), plan.resultType.getElementType()) + .getResult(); + + SmallVector groups; + if (failed(collectWholeBatchFragmentGroups(state, targetClass, plan, groups))) + return failure(); + + for (const WholeBatchFragmentGroup& group : groups) { + FailureOr updated = emitWholeBatchFragmentGroup(state, targetClass, result, group, loc); + if (failed(updated)) + return failure(); + result = *updated; + } + + state.availableValues.record(key, targetClass.id, result); + return result; +} + +// ----------------------------------------------------------------------------- +// Run materialization helpers. 
+// ----------------------------------------------------------------------------- + +FailureOr materializeProjectedWholeBatchInputFromFragments(MaterializerState& state, + MaterializedClass& targetClass, + ProducerKey key, + Type resultType, + Location loc) { + auto batch = dyn_cast_or_null(key.instance.op); + auto resultTensorType = dyn_cast(resultType); + if (!batch || !resultTensorType || !resultTensorType.hasStaticShape()) + return failure(); + + FailureOr projection = getBatchResultProjectionInsert(batch, key.resultIndex); + if (failed(projection)) + return failure(); + + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("missing compute_batch lane argument while materializing projected whole-batch input"); + + uint32_t laneEnd = key.instance.laneStart + key.instance.laneCount; + if (laneEnd > static_cast(batch.getLaneCount())) + return failure(); + + if (targetClass.isBatch) { + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) + .getResult(); + + for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { + ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); + std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); + if (!fragment) + return failure(); + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); + if (failed(offsets) || failed(sizes) || failed(strides)) + return failure(); + + SmallVector offsetAttrs; + SmallVector sizeAttrs; + SmallVector strideAttrs; + offsetAttrs.reserve(offsets->size()); + sizeAttrs.reserve(sizes->size()); + 
strideAttrs.reserve(strides->size()); + for (auto [offset, size, stride] : llvm::zip(*offsets, *sizes, *strides)) { + offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); + sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); + strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); + } + + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + *fragment, + targetClass.op, + "projected whole-batch assembly tried to reuse a tensor from another materialized class", + laneKey); + if (failed(localFragment)) + return failure(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + result = tensor::InsertSliceOp::create( + state.rewriter, loc, *localFragment, result, offsetAttrs, sizeAttrs, strideAttrs) + .getResult(); + } + + state.availableValues.record(key, targetClass.id, result); + return result; + } + + SmallVector groups; + auto getOrCreateReceiveGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { + return group.kind == ProjectedWholeBatchFragmentSourceKind::DeferredReceive && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + ProjectedWholeBatchFragmentGroup group; + group.kind = ProjectedWholeBatchFragmentSourceKind::DeferredReceive; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + auto getOrCreatePackedGroup = [&](Value packed, + RankedTensorType packedSourceType, + RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { + return group.kind == ProjectedWholeBatchFragmentSourceKind::PackedValue && group.fragmentType == fragmentType + && group.packed == packed && group.packedSourceType == packedSourceType; + }); + if (groupIt == groups.end()) { 
+ ProjectedWholeBatchFragmentGroup group; + group.kind = ProjectedWholeBatchFragmentSourceKind::PackedValue; + group.fragmentType = fragmentType; + group.packed = packed; + group.packedSourceType = packedSourceType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + auto getOrCreateDirectGroup = [&](RankedTensorType fragmentType) -> ProjectedWholeBatchFragmentGroup& { + auto groupIt = llvm::find_if(groups, [&](const ProjectedWholeBatchFragmentGroup& group) { + return group.kind == ProjectedWholeBatchFragmentSourceKind::DirectValue && group.fragmentType == fragmentType; + }); + if (groupIt == groups.end()) { + ProjectedWholeBatchFragmentGroup group; + group.kind = ProjectedWholeBatchFragmentSourceKind::DirectValue; + group.fragmentType = fragmentType; + groups.push_back(std::move(group)); + groupIt = std::prev(groups.end()); + } + return *groupIt; + }; + + for (uint32_t lane = key.instance.laneStart; lane < laneEnd; ++lane) { + ProducerKey laneKey = getBatchLaneProducerKey(batch, lane, 1, key.resultIndex); + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, lane); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, lane); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, lane); + if (failed(offsets) || failed(sizes) || failed(strides)) + return failure(); + + bool grouped = false; + if (std::optional exact = state.availableValues.lookupExact(laneKey, targetClass.id)) { + if (auto receive = exact->getDefiningOp()) { + auto fragmentType = dyn_cast(receive.getOutput().getType()); + if (fragmentType && receive.getOutput().use_empty()) { + ProjectedWholeBatchFragmentGroup& group = getOrCreateReceiveGroup(fragmentType); + if (appendConstantChannelReceiveMessage(group.messages, receive)) { + appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); + 
group.redundantOps.push_back(receive.getOperation()); + grouped = true; + } + } + } + } + + if (grouped) + continue; + + std::optional fragment = state.availableValues.lookup(state, laneKey, targetClass.id); + if (!fragment) + return failure(); + + auto fragmentType = dyn_cast(fragment->getType()); + if (!fragmentType) + return failure(); + + if (auto extract = fragment->getDefiningOp()) { + if (std::optional packedIndex = getStaticProjectedPackedFragmentIndex(extract)) { + auto packedSourceType = dyn_cast(extract.getSource().getType()); + if (packedSourceType) { + ProjectedWholeBatchFragmentGroup& group = + getOrCreatePackedGroup(extract.getSource(), packedSourceType, fragmentType); + group.packedIndices.push_back(*packedIndex); + appendProjectedInsertCoordinates(group, *offsets, *sizes, *strides); + group.redundantOps.push_back(extract.getOperation()); + continue; + } + } + } + + ProjectedWholeBatchFragmentGroup& group = getOrCreateDirectGroup(fragmentType); + group.directFragments.push_back({*fragment, *offsets, *sizes, *strides}); + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value result = + tensor::EmptyOp::create(state.rewriter, loc, resultTensorType.getShape(), resultTensorType.getElementType()) + .getResult(); + + for (const ProjectedWholeBatchFragmentGroup& group : groups) { + FailureOr updated = failure(); + switch (group.kind) { + case ProjectedWholeBatchFragmentSourceKind::DeferredReceive: + updated = emitProjectedWholeBatchFragmentInsertLoop( + state, + targetClass, + result, + group, + [&](Value flatIndex) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, group.messages, flatIndex, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, group.messages, flatIndex, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, group.messages, flatIndex, loc); + return SpatChannelReceiveOp::create( + state.rewriter, loc, group.fragmentType, channelId, 
sourceCoreId, targetCoreId) + .getOutput(); + }, + loc); + break; + case ProjectedWholeBatchFragmentSourceKind::PackedValue: + updated = emitProjectedWholeBatchFragmentInsertLoop( + state, + targetClass, + result, + group, + [&](Value flatIndex) -> FailureOr { + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + extractOffsets.reserve(group.packedSourceType.getRank()); + extractSizes.reserve(group.packedSourceType.getRank()); + extractStrides.reserve(group.packedSourceType.getRank()); + extractOffsets.push_back(createIndexedOrStaticIndex( + state, targetClass.op, group.packedIndices, flatIndex, loc)); + extractSizes.push_back(state.rewriter.getIndexAttr(1)); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < group.packedSourceType.getRank(); ++dim) { + extractOffsets.push_back(state.rewriter.getIndexAttr(0)); + extractSizes.push_back(state.rewriter.getIndexAttr(group.packedSourceType.getDimSize(dim))); + extractStrides.push_back(state.rewriter.getIndexAttr(1)); + } + + FailureOr packed = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + group.packed, + targetClass.op, + "projected whole-batch packed fragment assembly tried to reuse a tensor from another materialized class"); + if (failed(packed)) + return failure(); + + return tensor::ExtractSliceOp::create( + state.rewriter, + loc, + group.fragmentType, + *packed, + extractOffsets, + extractSizes, + extractStrides) + .getResult(); + }, + loc); + break; + case ProjectedWholeBatchFragmentSourceKind::DirectValue: { + updated = result; + for (const ProjectedWholeBatchDirectFragment& fragment : group.directFragments) { + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + fragment.fragment, + targetClass.op, + "projected whole-batch assembly tried to reuse a tensor from another materialized class"); + if (failed(localFragment)) + return failure(); + + SmallVector offsetAttrs; 
+ SmallVector sizeAttrs; + SmallVector strideAttrs; + for (auto [offset, size, stride] : llvm::zip(fragment.offsets, fragment.sizes, fragment.strides)) { + offsetAttrs.push_back(state.rewriter.getIndexAttr(offset)); + sizeAttrs.push_back(state.rewriter.getIndexAttr(size)); + strideAttrs.push_back(state.rewriter.getIndexAttr(stride)); + } + updated = tensor::InsertSliceOp::create( + state.rewriter, loc, *localFragment, *updated, offsetAttrs, sizeAttrs, strideAttrs) + .getResult(); + } + break; + } + } + if (failed(updated)) + return failure(); + result = *updated; + } + + for (const ProjectedWholeBatchFragmentGroup& group : groups) + for (Operation* redundantOp : group.redundantOps) + if (redundantOp && redundantOp->use_empty()) + redundantOp->erase(); + + state.availableValues.record(key, targetClass.id, result); + return result; +} + +FailureOr materializeWholeBatchInput( + MaterializerState& state, MaterializedClass& targetClass, ProducerKey key, Type resultType, Location loc) { + FailureOr plan = buildWholeBatchAssemblyPlan(state, targetClass, key, resultType); + if (succeeded(plan)) + return emitWholeBatchAssemblyPlan(state, targetClass, key, *plan, loc); + + return materializeProjectedWholeBatchInputFromFragments(state, targetClass, key, resultType, loc); +} + +FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerState& state, + MaterializedClass& sourceClass, + SpatComputeBatch sourceBatch, + size_t resultIndex, + ArrayRef run, + Value packed, + RankedTensorType fragmentType, + Value originalOutput, + Location loc) { + if (!hasLiveExternalUseCached(state, originalOutput)) + return false; + if (packed.getType() == originalOutput.getType() || fragmentType == originalOutput.getType()) + return false; + + auto resultType = dyn_cast(originalOutput.getType()); + if (!resultType || !resultType.hasStaticShape()) + return false; + + FailureOr projection = getBatchResultProjectionInsert(sourceBatch, resultIndex); + if (failed(projection)) + return 
false; + + std::optional laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) { + sourceBatch.emitOpError("missing compute_batch lane argument while recording projected host fragments"); + return failure(); + } + + for (auto [runIndex, slot] : llvm::enumerate(run)) { + if (slot.peers.size() != 1) { + sourceClass.op->emitError("projected scalar host output publication expects scalar one-peer run slots"); + return failure(); + } + + const ComputeInstance& peer = slot.peers.front(); + if (peer.op != sourceBatch.getOperation()) { + sourceClass.op->emitError("projected scalar host output run changed source operation"); + return failure(); + } + if (peer.laneCount != 1) { + sourceClass.op->emitError("projected scalar host output publication expects one logical lane per packed slot") + << " laneStart=" << peer.laneStart << " laneCount=" << peer.laneCount; + return failure(); + } + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, peer.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, peer.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, peer.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) { + sourceClass.op->emitError("failed to evaluate projected host output slice for logical lane ") + << peer.laneStart; + return failure(); + } + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + Value fragment = getPackedSliceForRunIndex(state, sourceClass.op, packed, fragmentType, runIndex, loc); + + state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { + originalOutput, + sourceClass.id, + fragment, + fragmentType, + SmallVector(*offsets), + SmallVector(*sizes), + SmallVector(*strides), + peer.laneStart, + loc}); + } + + return true; +} + +std::optional getOriginalOutputResultIndex(Value originalOutput) { + auto result = 
dyn_cast(originalOutput); + if (!result) + return std::nullopt; + return static_cast(result.getResultNumber()); +} + +FailureOr recordProjectedBatchHostFragmentsFromBatchValue(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& ownerClass, + ArrayRef keys, + Value payload, + Value originalOutput, + Location loc) { + if (!sourceClass.isBatch) + return false; + if (ownerClass.isBatch) + return false; + if (!hasLiveExternalUseCached(state, originalOutput)) + return false; + if (payload.getType() == originalOutput.getType()) + return false; + + auto sourceBatch = dyn_cast_or_null(originalOutput.getDefiningOp()); + if (!sourceBatch || sourceBatch.getNumResults() == 0) + return false; + + auto resultType = dyn_cast(originalOutput.getType()); + auto fragmentType = dyn_cast(payload.getType()); + if (!resultType || !resultType.hasStaticShape() || !fragmentType || !fragmentType.hasStaticShape()) + return false; + + std::optional resultIndex = getOriginalOutputResultIndex(originalOutput); + if (!resultIndex) + return false; + + FailureOr projection = getBatchResultProjectionInsert(sourceBatch, *resultIndex); + if (failed(projection)) + return false; + + std::optional laneArg = sourceBatch.getLaneArgument(); + if (!laneArg) { + sourceBatch.emitOpError("missing compute_batch lane argument while recording projected batch host fragments"); + return failure(); + } + + if (keys.size() != sourceClass.cpus.size()) { + sourceClass.op->emitError("projected batch host publication expects one producer key per materialized batch lane") + << " keyCount=" << keys.size() << " laneCount=" << sourceClass.cpus.size(); + return failure(); + } + + MessageVector messages; + messages.channelIds.reserve(sourceClass.cpus.size()); + messages.sourceCoreIds.reserve(sourceClass.cpus.size()); + messages.targetCoreIds.reserve(sourceClass.cpus.size()); + + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, ownerClass.cpus.front(), "projected batch host output target core 
id"); + if (failed(checkedTargetCpu)) + return failure(); + + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "projected batch host output source core id"); + if (failed(checkedSourceCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + (void) lane; + } + + if (failed(appendSend(state, sourceClass, payload, messages, loc))) + return failure(); + + for (auto [lane, key] : llvm::enumerate(keys)) { + if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != *resultIndex || key.instance.laneCount != 1) { + sourceClass.op->emitError("projected batch host publication received an incompatible producer key") + << " laneStart=" << key.instance.laneStart << " laneCount=" << key.instance.laneCount + << " resultIndex=" << key.resultIndex; + return failure(); + } + + FailureOr> offsets = + evaluateStaticProjectionIndices(projection->getMixedOffsets(), *laneArg, key.instance.laneStart); + FailureOr> sizes = + evaluateStaticProjectionIndices(projection->getMixedSizes(), *laneArg, key.instance.laneStart); + FailureOr> strides = + evaluateStaticProjectionIndices(projection->getMixedStrides(), *laneArg, key.instance.laneStart); + if (failed(offsets) || failed(sizes) || failed(strides)) { + sourceClass.op->emitError("failed to evaluate projected batch host output slice") + << " laneStart=" << key.instance.laneStart; + return failure(); + } + + state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { + originalOutput, + sourceClass.id, + payload, + fragmentType, + SmallVector(*offsets), + SmallVector(*sizes), + SmallVector(*strides), + key.instance.laneStart, + loc, + true, + messages.slice(lane, 1)}); + } + + return true; +} + +LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { + if (state.pendingProjectedHostOutputFragments.empty()) + return success(); + + DenseMap> 
byOutput; + for (PendingProjectedHostOutputFragment& fragment : state.pendingProjectedHostOutputFragments) + byOutput[fragment.originalOutput].push_back(&fragment); + + SmallVector outputs; + outputs.reserve(byOutput.size()); + for (const auto& entry : byOutput) + outputs.push_back(entry.first); + llvm::sort(outputs, [](Value lhs, Value rhs) { + return reinterpret_cast(lhs.getAsOpaquePointer()) + < reinterpret_cast(rhs.getAsOpaquePointer()); + }); + + for (Value originalOutput : outputs) { + auto ownerIt = state.hostOutputOwners.find(originalOutput); + if (ownerIt == state.hostOutputOwners.end()) { + Operation* anchor = originalOutput.getDefiningOp() ? originalOutput.getDefiningOp() : state.func.getOperation(); + return anchor->emitError("missing host owner for projected host output fragments"); + } + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (ownerClass.isBatch) + return ownerClass.op->emitError("projected scalar host output finalization expected a scalar host owner"); + + auto resultType = dyn_cast(originalOutput.getType()); + if (!resultType || !resultType.hasStaticShape()) + return ownerClass.op->emitError("projected host output must have static ranked tensor type"); + + SmallVector& fragments = byOutput[originalOutput]; + llvm::sort(fragments, [](const PendingProjectedHostOutputFragment* lhs, + const PendingProjectedHostOutputFragment* rhs) { + if (lhs->sourceLane != rhs->sourceLane) + return lhs->sourceLane < rhs->sourceLane; + if (lhs->sourceClass != rhs->sourceClass) + return lhs->sourceClass < rhs->sourceClass; + return std::lexicographical_compare(lhs->offsets.begin(), + lhs->offsets.end(), + rhs->offsets.begin(), + rhs->offsets.end()); + }); + + state.rewriter.setInsertionPoint(ownerClass.body->getTerminator()); + Location loc = fragments.front()->loc; + Value assembled = tensor::EmptyOp::create( + state.rewriter, loc, resultType.getShape(), resultType.getElementType()) + .getResult(); + + for 
(PendingProjectedHostOutputFragment* fragmentRecord : fragments) { + Value fragment = fragmentRecord->fragment; + MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; + + if (fragmentRecord->sourceClass != ownerClass.id) { + MessageVector messages; + if (fragmentRecord->sendAlreadyEmitted) { + messages = fragmentRecord->messages; + } else { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, + sourceClass.cpus.front(), + "projected host output source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass.op, + ownerClass.cpus.front(), + "projected host output target core id"); + if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + if (failed(appendSend(state, sourceClass, fragment, messages, fragmentRecord->loc))) + return failure(); + } + fragment = appendReceive(state, ownerClass, fragmentRecord->fragmentType, messages, fragmentRecord->loc); + } else { + FailureOr localFragment = materializeTensorValueForMaterializedClassUse( + state, + ownerClass, + fragment, + ownerClass.op, + "projected host output assembly tried to reuse a non-local fragment tensor"); + if (failed(localFragment)) + return failure(); + fragment = *localFragment; + } + + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + offsets.reserve(fragmentRecord->offsets.size()); + sizes.reserve(fragmentRecord->sizes.size()); + strides.reserve(fragmentRecord->strides.size()); + for (auto [offset, size, stride] : llvm::zip(fragmentRecord->offsets, + fragmentRecord->sizes, + fragmentRecord->strides)) { + offsets.push_back(state.rewriter.getIndexAttr(offset)); + sizes.push_back(state.rewriter.getIndexAttr(size)); + strides.push_back(state.rewriter.getIndexAttr(stride)); + } + + state.rewriter.setInsertionPoint(ownerClass.body->getTerminator()); + assembled = tensor::InsertSliceOp::create( + state.rewriter, fragmentRecord->loc, fragment, assembled, 
offsets, sizes, strides) + .getResult(); + } + + if (failed(setHostOutputValue(state, ownerClass, originalOutput, assembled))) + return failure(); + } + + return success(); +} + +FailureOr resolveInputValue(MaterializerState& state, + MaterializedClass& targetClass, + Value input, + const ComputeInstance& consumerInstance, + CloneIndexingContext indexing) { + auto rejectNonLocalResolvedValue = [&](Value resolved) -> FailureOr { + if (!isTensorValueDefinedInDifferentMaterializedClass(resolved, targetClass)) + return resolved; + + std::optional producer = getInputRequestProducerKey(input, consumerInstance); + emitNonLocalMaterializedClassValueDiagnostic(consumerInstance.op, + targetClass, + "input resolution tried to reuse a tensor from another materialized class", + resolved, + producer); + return failure(); + }; + + if (isConstantLike(input)) + return input; + + if (std::optional producer = getInputRequestProducerKey(input, consumerInstance)) { + if (indexing.runSlotIndex) { + if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { + FailureOr received = materializeIndexedBatchRunReceive( + state, targetClass, *indexedRun, *indexing.runSlotIndex, consumerInstance.op->getLoc()); + if (failed(received)) + return failure(); + return rejectNonLocalResolvedValue(*received); + } + } + + if (std::optional value = state.availableValues.lookup(state, *producer, targetClass.id)) + return rejectNonLocalResolvedValue(*value); + + + if (IndexedBatchRunValue* indexedRun = state.availableValues.lookupIndexedBatchRun(*producer, targetClass.id)) { + size_t laneCount = targetClass.cpus.size(); + for (auto [slotIndex, slot] : llvm::enumerate(indexedRun->slots)) { + if (!llvm::is_contained(slot.keys, *producer)) + continue; + + MessageVector messages = indexedRun->messages.slice(slotIndex * laneCount, laneCount); + Value received = + appendReceive(state, targetClass, indexedRun->fragmentType, messages, 
consumerInstance.op->getLoc()); + for (ProducerKey slotKey : slot.keys) + state.availableValues.record(slotKey, targetClass.id, received); + return rejectNonLocalResolvedValue(received); + } + } + + if (isWholeBatchProducerKey(*producer)) { + FailureOr wholeBatch = + materializeWholeBatchInput(state, targetClass, *producer, input.getType(), consumerInstance.op->getLoc()); + if (failed(wholeBatch)) + consumerInstance.op->emitError("failed to materialize whole-batch input") + << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + if (failed(wholeBatch)) + return failure(); + return rejectNonLocalResolvedValue(*wholeBatch); + } + + consumerInstance.op->emitError("failed to resolve producer value") + << " from op '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + return failure(); + } + + if (isTensorValueDefinedInDifferentMaterializedClass(input, targetClass)) { + emitNonLocalMaterializedClassValueDiagnostic( + consumerInstance.op, + targetClass, + "input resolution tried to append a tensor from another materialized class as a normal input", + input); + return failure(); + } + + return appendInput(state, targetClass, input); +} + +bool hasProjectedInputReplacement(MaterializerState& state, + SpatComputeBatch batch, + unsigned inputIndex, + ClassId classId) { + std::optional match = getProjectedInputSliceMatch(state, batch, inputIndex); + if (!match) + return false; + + auto replacementIt = state.projectedExtractReplacements.find(match->extract.getOperation()); + if (replacementIt == state.projectedExtractReplacements.end()) + return false; + + return replacementIt->second.find(classId) != replacementIt->second.end(); +} + +void mapWeights(MaterializerState& state, + MaterializedClass& 
targetClass, + const ComputeInstance& instance, + IRMapping& mapper) { + Operation* op = instance.op; + if (auto compute = dyn_cast(op)) { + for (auto [index, weight] : llvm::enumerate(compute.getWeights())) { + auto weightArg = compute.getWeightArgument(index); + assert(weightArg && "expected compute weight block argument"); + mapper.map(*weightArg, appendWeight(state, targetClass, weight)); + } + return; + } + + auto batch = cast(op); + for (auto [index, weight] : llvm::enumerate(batch.getWeights())) { + auto weightArg = batch.getWeightArgument(index); + assert(weightArg && "expected compute_batch weight block argument"); + mapper.map(*weightArg, appendWeight(state, targetClass, weight)); + } +} + +LogicalResult mapInputs(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper, + CloneIndexingContext indexing) { + auto mapResolvedInput = [&](Value resolved) -> FailureOr { + return materializeTensorValueForMaterializedClassUse( + state, + targetClass, + resolved, + targetClass.op, + "input mapping tried to reuse a tensor from another materialized class"); + }; + + Operation* op = instance.op; + if (auto compute = dyn_cast(op)) { + for (auto [index, input] : llvm::enumerate(compute.getInputs())) { + FailureOr mapped = resolveInputValue(state, targetClass, input, instance, indexing); + if (failed(mapped)) { + std::optional producer = getInputRequestProducerKey(input, instance); + auto diagnostic = compute.emitOpError("failed to resolve materialized compute input") << " #" << index; + if (producer) { + diagnostic << " from '" << producer->instance.op->getName() << "' laneStart=" << producer->instance.laneStart + << " laneCount=" << producer->instance.laneCount << " resultIndex=" << producer->resultIndex; + } + return failure(); + } + auto inputArg = compute.getInputArgument(index); + if (!inputArg) + return compute.emitOpError("expected compute input block argument while materializing inputs"); + FailureOr 
remapped = mapResolvedInput(*mapped); + if (failed(remapped)) { + emitNonLocalMaterializedClassValueDiagnostic(compute, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + getInputRequestProducerKey(input, instance)); + return failure(); + } + mapper.map(*inputArg, *remapped); + } + return success(); + } + + auto batch = cast(op); + for (auto [index, input] : llvm::enumerate(batch.getInputs())) { + if (hasProjectedInputReplacement(state, batch, static_cast(index), targetClass.id)) + continue; + + FailureOr mapped = failure(); + if (std::optional wholeBatchProducer = getWholeBatchProducerKeyForDirectBatchResult(input); + wholeBatchProducer && !canUseProjectedLaneInput(state, batch, static_cast(index), input, instance)) { + mapped = materializeWholeBatchInput( + state, targetClass, *wholeBatchProducer, input.getType(), batch.getOperation()->getLoc()); + if (failed(mapped)) + return batch.emitOpError("failed to materialize whole-batch compute_batch input") + << " #" << index << " from '" << wholeBatchProducer->instance.op->getName() + << "' laneStart=" << wholeBatchProducer->instance.laneStart + << " laneCount=" << wholeBatchProducer->instance.laneCount + << " resultIndex=" << wholeBatchProducer->resultIndex; + } else { + mapped = resolveInputValue(state, targetClass, input, instance, indexing); + if (failed(mapped)) + return batch.emitOpError("failed to resolve materialized compute_batch input"); + } + + auto inputArg = batch.getInputArgument(index); + if (!inputArg) + return batch.emitOpError("expected compute_batch input block argument while materializing inputs"); + FailureOr remapped = mapResolvedInput(*mapped); + if (failed(remapped)) { + emitNonLocalMaterializedClassValueDiagnostic(batch, + targetClass, + "mapInputs tried to append a tensor from another materialized class", + *mapped, + getInputRequestProducerKey(input, instance)); + return failure(); + } + mapper.map(*inputArg, *remapped); + } + return success(); 
+} + +SmallVector collectMappedBatchOutputs(SpatComputeBatch batch, IRMapping& mapper) { + SmallVector outputs(batch.getNumResults(), Value {}); + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return outputs; + + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; + + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return outputs; + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (resultIndex >= outputs.size()) + continue; + outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource()); + } + + return outputs; +} + +SmallVector collectBatchOutputFragmentTypes(SpatComputeBatch batch) { + SmallVector types(batch.getNumResults(), Type {}); + auto inParallel = dyn_cast_or_null(batch.getBody().front().getTerminator()); + if (!inParallel) + return types; + + auto firstOutputArg = batch.getOutputArgument(0); + if (!firstOutputArg) + return types; + + for (Operation& op : inParallel.getRegion().front()) { + auto insert = dyn_cast(&op); + if (!insert) + continue; + + auto outputArg = dyn_cast(insert.getDest()); + if (!outputArg || outputArg.getOwner() != &batch.getBody().front()) + continue; + + unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); + if (resultIndex >= types.size()) + continue; + + types[resultIndex] = insert.getSource().getType(); + } + + return types; +} + +SmallVector& getBatchOutputFragmentTypesCached(MaterializerState& state, SpatComputeBatch batch) { + auto [it, inserted] = state.batchOutputFragmentTypesCache.try_emplace(batch.getOperation(), SmallVector {}); + if (inserted) + it->second = collectBatchOutputFragmentTypes(batch); + return it->second; +} + +ArrayRef 
getComputeInstanceOutputValuesCached(MaterializerState& state, ComputeInstance instance) { + auto [it, inserted] = state.computeInstanceOutputsCache.try_emplace(instance, SmallVector {}); + if (inserted) + it->second = getComputeInstanceOutputValues(instance); + return it->second; +} + +std::optional lookupProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract) { + auto replacementIt = state.projectedExtractReplacements.find(extract.getOperation()); + if (replacementIt == state.projectedExtractReplacements.end()) + return std::nullopt; + + auto classIt = replacementIt->second.find(targetClass.id); + if (classIt == replacementIt->second.end()) + return std::nullopt; + + return classIt->second; +} + +LogicalResult applyProjectedExtractReplacementsInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + Operation& originalOp, + Operation& clonedOp, + CloneIndexingContext indexing, + IRMapping& mapper) { + if (auto originalExtract = dyn_cast(&originalOp)) { + if (std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, originalExtract)) { + auto clonedExtract = dyn_cast(&clonedOp); + if (!clonedExtract) + return targetClass.op->emitError("projected replacement lost extract structure during cloning"); + + state.rewriter.setInsertionPoint(clonedExtract); + FailureOr projected = materializeProjectedExtractReplacement( + state, targetClass, clonedExtract, *replacement, indexing.projectionSlotIndex, &mapper); + if (failed(projected)) + return failure(); + + clonedExtract.getResult().replaceAllUsesWith(*projected); + state.rewriter.eraseOp(clonedExtract); + return success(); + } + } + + if (originalOp.getNumRegions() != clonedOp.getNumRegions()) + return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned regions"); + + for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { + if 
(std::distance(originalRegion.begin(), originalRegion.end()) + != std::distance(clonedRegion.begin(), clonedRegion.end())) + return targetClass.op->emitError("projected replacement traversal found non-isomorphic cloned blocks"); + + for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { + auto originalIt = originalBlock.begin(); + auto clonedIt = clonedBlock.begin(); + while (originalIt != originalBlock.end() && clonedIt != clonedBlock.end()) { + Operation& originalNestedOp = *originalIt++; + Operation* currentClonedOp = &*clonedIt++; + if (failed(applyProjectedExtractReplacementsInClonedOp( + state, targetClass, originalNestedOp, *currentClonedOp, indexing, mapper))) + return failure(); + } + if (originalIt != originalBlock.end() || clonedIt != clonedBlock.end()) + return targetClass.op->emitError("projected replacement traversal found mismatched cloned operations"); + } + } + + return success(); +} + +LogicalResult mapClonedRegionBlockArguments(Operation& originalOp, Operation& clonedOp, IRMapping& mapper) { + if (originalOp.getNumRegions() != clonedOp.getNumRegions()) + return clonedOp.emitError("cloned operation has a different number of regions than the source operation"); + + for (auto [originalRegion, clonedRegion] : llvm::zip(originalOp.getRegions(), clonedOp.getRegions())) { + if (std::distance(originalRegion.begin(), originalRegion.end()) + != std::distance(clonedRegion.begin(), clonedRegion.end())) + return clonedOp.emitError("cloned operation has a different number of blocks than the source operation"); + + for (auto [originalBlock, clonedBlock] : llvm::zip(originalRegion.getBlocks(), clonedRegion.getBlocks())) { + if (originalBlock.getNumArguments() != clonedBlock.getNumArguments()) + return clonedOp.emitError("cloned operation block has a different number of arguments than the source block"); + + for (auto [originalArg, clonedArg] : llvm::zip(originalBlock.getArguments(), 
clonedBlock.getArguments())) + if (!mapper.contains(originalArg)) + mapper.map(originalArg, clonedArg); + + if (std::distance(originalBlock.begin(), originalBlock.end()) != std::distance(clonedBlock.begin(), clonedBlock.end())) + return clonedOp.emitError("cloned operation block has a different number of operations than the source block"); + + auto originalIt = originalBlock.begin(); + auto clonedIt = clonedBlock.begin(); + while (originalIt != originalBlock.end()) { + Operation& originalNestedOp = *originalIt++; + Operation& clonedNestedOp = *clonedIt++; + if (failed(mapClonedRegionBlockArguments(originalNestedOp, clonedNestedOp, mapper))) + return failure(); + } + } + } + + return success(); +} + +LogicalResult cloneComputeTemplateBody(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + IRMapping& mapper, + CloneIndexingContext indexing) { + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); + for (Operation& op : sourceBlock.without_terminator()) { + if (auto extract = dyn_cast(&op)) { + if (std::optional replacement = + lookupProjectedExtractReplacement(state, targetClass, extract)) { + FailureOr projected = materializeProjectedExtractReplacement( + state, targetClass, extract, *replacement, indexing.projectionSlotIndex, &mapper); + if (failed(projected)) + return failure(); + + mapper.map(extract.getResult(), *projected); + continue; + } + } + + for (Value operand : op.getOperands()) { + if (mapper.contains(operand)) + continue; + + FailureOr localized = localizeMaterializedClassOperand( + state, + targetClass, + operand, + &op, + "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", + "cloneComputeTemplateBody produced an unsupported external non-tensor operand", + &mapper); + if (failed(localized)) + return failure(); + if (*localized != operand) + mapper.map(operand, *localized); + } + + Operation* cloned = state.rewriter.clone(op, mapper); + if 
(failed(mapClonedRegionBlockArguments(op, *cloned, mapper))) + return failure(); + if (failed(localizeCapturesInClonedOp(state, targetClass, *cloned, &mapper))) + return failure(); + if (op.getNumRegions() != 0 + && failed(applyProjectedExtractReplacementsInClonedOp(state, targetClass, op, *cloned, indexing, mapper))) + return failure(); + for (auto [oldResult, newResult] : llvm::zip(op.getResults(), cloned->getResults())) + mapper.map(oldResult, newResult); + } + + return success(); +} + +FailureOr materializeProjectedExtractReplacement(MaterializerState& state, + MaterializedClass& targetClass, + tensor::ExtractSliceOp extract, + const ProjectedExtractReplacement& replacement, + std::optional projectionSlotIndex, + IRMapping* mapper) { + if (failed(verifyProjectedFragmentLayout(targetClass.op, replacement.layout))) + return failure(); + + FailureOr localizedPayload = materializeTensorValueForMaterializedClassUse( + state, + targetClass, + replacement.payload, + targetClass.op, + "projected extract replacement tried to reuse a tensor from another materialized class", + std::nullopt, + mapper); + if (failed(localizedPayload)) + return failure(); + Value payload = *localizedPayload; + + if (replacement.layout.payloadFragmentCount == 1) + return payload; + + if (replacement.layout.payloadFragmentCount < replacement.layout.fragmentsPerLogicalSlot) + return targetClass.op->emitError("projected replacement payload is smaller than one logical slot"); + + Value intraSlotFragmentIndex = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + const auto linearizeProjectedLoopIndices = [&]() -> FailureOr { + if (replacement.layout.loopTripCounts.empty()) + return intraSlotFragmentIndex; + + SmallVector surroundingLoops; + for (Operation* current = extract->getParentOp(); current; current = current->getParentOp()) { + if (auto loop = dyn_cast(current)) + surroundingLoops.push_back(loop); + if (current == targetClass.op) + break; + } + 
std::reverse(surroundingLoops.begin(), surroundingLoops.end()); + + if (surroundingLoops.size() != replacement.layout.loopTripCounts.size()) + return targetClass.op->emitError("projected replacement loop structure does not match the collected descriptor"); + + Value linearizedIndex = intraSlotFragmentIndex; + for (auto [index, loop] : llvm::enumerate(surroundingLoops)) { + FailureOr localizedIv = + rematerializeIndexValueInClass(state, targetClass, loop.getInductionVar(), extract.getLoc(), mapper); + if (failed(localizedIv)) + return failure(); + Value iv = *localizedIv; + Value lowerBound = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopLowerBounds[index]); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopSteps[index]); + Value tripCount = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.loopTripCounts[index]); + + Value normalized = arith::SubIOp::create(state.rewriter, extract.getLoc(), iv, lowerBound).getResult(); + if (replacement.layout.loopSteps[index] != 1) + normalized = arith::DivUIOp::create(state.rewriter, extract.getLoc(), normalized, step).getResult(); + linearizedIndex = arith::MulIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, tripCount).getResult(); + linearizedIndex = + arith::AddIOp::create(state.rewriter, extract.getLoc(), linearizedIndex, normalized).getResult(); + } + return linearizedIndex; + }; + + FailureOr linearizedIndex = linearizeProjectedLoopIndices(); + if (failed(linearizedIndex)) + return failure(); + intraSlotFragmentIndex = *linearizedIndex; + + const auto computeProjectedPayloadFragmentIndex = [&]() -> FailureOr { + if (replacement.layout.payloadFragmentCount == replacement.layout.fragmentsPerLogicalSlot) { + if (replacement.layout.loopTripCounts.empty() && replacement.layout.fragmentsPerLogicalSlot != 1) + return targetClass.op->emitError("projected replacement is missing loop metadata 
for packed logical slot"); + return intraSlotFragmentIndex; + } + + if (!projectionSlotIndex) + return targetClass.op->emitError("packed projected extract replacement requires a fragment slot index"); + + FailureOr localProjectionSlotIndex = + rematerializeIndexValueInClass(state, targetClass, *projectionSlotIndex, extract.getLoc(), mapper); + if (failed(localProjectionSlotIndex)) + return failure(); + + Value fragmentsPerLogicalSlot = + getOrCreateIndexConstant(state.constantFolder, targetClass.op, replacement.layout.fragmentsPerLogicalSlot); + Value base = + arith::MulIOp::create(state.rewriter, extract.getLoc(), *localProjectionSlotIndex, fragmentsPerLogicalSlot) + .getResult(); + return arith::AddIOp::create(state.rewriter, extract.getLoc(), base, intraSlotFragmentIndex).getResult(); + }; + + FailureOr packedFragmentIndex = computeProjectedPayloadFragmentIndex(); + if (failed(packedFragmentIndex)) + return failure(); + + FailureOr packedOffset = scaleIndexByDim0SizeInClass( + state, targetClass, *packedFragmentIndex, replacement.layout.fragmentType.getDimSize(0), extract.getLoc()); + if (failed(packedOffset)) + return failure(); + return createDim0ExtractSliceInClass( + state, targetClass, extract.getLoc(), payload, *packedOffset, replacement.layout.fragmentType.getDimSize(0)); +} + +FailureOr materializeIndexedBatchRunReceive(MaterializerState& state, + MaterializedClass& targetClass, + IndexedBatchRunValue& run, + Value runSlotIndex, + Location loc) { + if (!targetClass.isBatch) + return targetClass.op->emitError("indexed batch run receive requires a batch target class"); + if (failed(run.messages.verify(targetClass.op))) + return failure(); + + Value flatIndex = createBatchRunFlatIndex(state, targetClass, runSlotIndex, loc); + std::optional preferredPeriod = static_cast(targetClass.cpus.size()); + Value channelId = createIndexedChannelId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); + Value sourceCoreId = 
createIndexedSourceCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, run.messages, flatIndex, loc, preferredPeriod); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + return SpatChannelReceiveOp::create(state.rewriter, loc, run.fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); +} + +LogicalResult localizeCapturesInOperationTree(MaterializerState& state, + MaterializedClass& targetClass, + Operation& root, + StringRef tensorContext, + StringRef genericContext, + IRMapping* mapper = nullptr) { + WalkResult walkResult = root.walk([&](Operation* nestedOp) -> WalkResult { + for (OpOperand& operand : nestedOp->getOpOperands()) { + Value current = operand.get(); + if (isValueLegalInMaterializedClassBody(current, targetClass)) + continue; + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(nestedOp); + FailureOr localized = + localizeMaterializedClassOperand(state, targetClass, current, nestedOp, tensorContext, genericContext, mapper); + if (failed(localized)) { + InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG failed to localize cloned scheduled-body operand"); + diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() + << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() + << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) + << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; + diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; + attachMaterializerOperationPrintNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation IR"); + attachMaterializerOperandListNote(diagnostic, 
nestedOp, targetClass, "RAPTOR_MATERIALIZER_DEBUG offending nested operation operands"); + attachMaterializerParentChainNote(diagnostic, nestedOp, "RAPTOR_MATERIALIZER_DEBUG offending nested operation parent chain"); + attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); + attachMaterializerOperationPrintNote(diagnostic, targetClass.op, "RAPTOR_MATERIALIZER_DEBUG target materialized op"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return WalkResult::interrupt(); + } + operand.set(*localized); + } + return WalkResult::advance(); + }); + + return walkResult.wasInterrupted() ? failure() : success(); +} + +LogicalResult localizeCapturesInClonedOp(MaterializerState& state, + MaterializedClass& targetClass, + Operation& clonedOp, + IRMapping* mapper) { + return localizeCapturesInOperationTree( + state, + targetClass, + clonedOp, + "cloneComputeTemplateBody tried to reuse a tensor from another materialized class", + "cloneComputeTemplateBody produced an unsupported external non-tensor operand", + mapper); +} + +LogicalResult localizeAllScheduledBodyCaptures(MaterializerState& state, MaterializedClass& targetClass) { + SmallVector bodyOps; + for (Operation& op : *targetClass.body) + op.walk([&](Operation* nestedOp) { bodyOps.push_back(nestedOp); }); + + for (Operation* nestedOp : bodyOps) { + if (nestedOp->getBlock() == nullptr) + continue; + for (OpOperand& operand : nestedOp->getOpOperands()) { + Value current = operand.get(); + if (isValueLegalInMaterializedClassBody(current, targetClass)) + continue; + + OpBuilder::InsertionGuard guard(state.rewriter); + state.rewriter.setInsertionPoint(nestedOp); + FailureOr localized = localizeMaterializedClassOperand( + state, + targetClass, + current, + nestedOp, + "final scheduled body capture localization tried to reuse a tensor from another materialized class", + "final scheduled body capture localization found an unsupported external non-tensor operand"); + if (failed(localized)) { 
+ InFlightDiagnostic diagnostic = targetClass.op->emitError( + "RAPTOR_MATERIALIZER_DEBUG failed to localize final scheduled-body operand"); + diagnostic << " targetClass=" << targetClass.id << " nestedOp='" << nestedOp->getName() + << "' operand#" << operand.getOperandNumber() << " operandType=" << current.getType() + << " offendingIR=\"" << truncateMaterializerDebugString(stringifyOperationForMaterializerDebug(nestedOp)) + << "\" offendingOperands=\"" << formatMaterializerOperandListInline(nestedOp, targetClass) + << "\" parentChain=\"" << formatMaterializerParentChainInline(nestedOp) << "\""; + diagnostic.attachNote(nestedOp->getLoc()) << "offending nested operation"; + attachMaterializerValueOriginNote(diagnostic, current, "offending operand"); + attachMaterializedClassBodySummary(diagnostic, targetClass); + return failure(); + } + operand.set(*localized); + } + } + + return success(); +} + +FailureOr> cloneInstanceBody(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef peers, + CloneIndexingContext indexing) { + assert(!peers.empty() && "expected at least one peer instance"); + const ComputeInstance& instance = peers.front(); + Operation* sourceOp = instance.op; + Location loc = sourceOp->getLoc(); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + IRMapping mapper; + if (auto batch = dyn_cast(sourceOp)) { + for (const ComputeInstance& peer : peers) { + if (peer.op != sourceOp) { + sourceOp->emitError("equivalence class slot contains different source compute_batch operations"); + return failure(); + } + } + auto laneArg = batch.getLaneArgument(); + if (!laneArg) { + sourceOp->emitError("expected source compute_batch lane block argument"); + return failure(); + } + mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc)); + } + + OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); + + mapWeights(state, targetClass, instance, mapper); + if (failed(mapInputs(state, 
targetClass, instance, mapper, indexing))) + return failure(); + + state.rewriter.restoreInsertionPoint(cloneInsertionPoint); + if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) + return failure(); + + if (auto compute = dyn_cast(sourceOp)) { + Block& sourceBlock = getComputeInstanceTemplateBlock(instance); + auto yield = dyn_cast_or_null(sourceBlock.getTerminator()); + if (!yield) { + compute.emitOpError("expected spat.yield terminator while materializing compute"); + return failure(); + } + + SmallVector outputs; + outputs.reserve(yield.getNumOperands()); + for (Value yielded : yield.getOutputs()) + outputs.push_back(mapper.lookupOrDefault(yielded)); + return outputs; + } + + auto batch = cast(sourceOp); + if (batch.getNumResults() == 0) + return SmallVector {}; + + SmallVector outputs = collectMappedBatchOutputs(batch, mapper); + for (Value output : outputs) + if (!output) { + batch.emitOpError("failed to recover yielded per-lane value for compute_batch result"); + return failure(); + } + return outputs; +} + +bool sameDestinationClasses(ArrayRef lhs, ArrayRef rhs) { + if (lhs.size() != rhs.size()) + return false; + for (auto [lhsClass, rhsClass] : llvm::zip(lhs, rhs)) + if (lhsClass != rhsClass) + return false; + return true; +} + +SmallVector +collectDestinationClassesForRun(MaterializerState& state, ArrayRef run, size_t resultIndex) { + SmallVector destinations; + + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + ProducerKey key {peer, resultIndex}; + for (ClassId destinationClass : getDestinationClasses(state, key)) + if (!llvm::is_contained(destinations, destinationClass)) + destinations.push_back(destinationClass); + } + } + + llvm::sort(destinations); + return destinations; +} + +SmallVector groupBatchRunOutputsByDestination(MaterializerState& state, + ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + 
assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + + SmallVector groups; + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); + + for (auto [resultIndex, output] : llvm::enumerate(outputs)) { + SmallVector destinations = collectDestinationClassesForRun(state, run, resultIndex); + + auto existingGroup = llvm::find_if(groups, [&](const OutputDestinationGroup& group) { + return sameDestinationClasses(group.destinationClasses, destinations); + }); + + if (existingGroup != groups.end()) { + existingGroup->resultIndices.push_back(resultIndex); + continue; + } + + OutputDestinationGroup group; + group.resultIndices.push_back(resultIndex); + group.destinationClasses = std::move(destinations); + groups.push_back(std::move(group)); + } + + return groups; +} + +FailureOr getPackedRunTensorType(Type elementType, size_t runSize) { + auto tensorType = dyn_cast(elementType); + if (!tensorType || !tensorType.hasStaticShape() || tensorType.getRank() == 0) + return failure(); + + SmallVector shape(tensorType.getShape()); + shape[0] *= static_cast(runSize); + return RankedTensorType::get(shape, tensorType.getElementType(), tensorType.getEncoding()); +} + +LogicalResult registerDeferredLocalPackedRunValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Type fragmentType, + Location loc) { + if (keys.empty()) + return success(); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return materializedClass.op->emitError("deferred local packed run expects static ranked fragment type"); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("deferred local packed run expects one 
producer result"); + + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("deferred local packed run expects one lane per fragment"); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredLocalCompute; + packedRun.fragmentType = rankedFragmentType; + + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult registerPackedRunValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef keys, + Value packed, + Type fragmentType, + Location loc) { + if (keys.empty()) + return success(); + + FailureOr expectedPackedType = getPackedRunTensorType(fragmentType, keys.size()); + if (failed(expectedPackedType)) + return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); + + if (packed.getType() != *expectedPackedType) + return materializedClass.op->emitError("packed run value has unexpected tensor type"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return materializedClass.op->emitError("packed run registration expects static ranked fragment type"); + + Operation* sourceOp = keys.front().instance.op; + size_t resultIndex = keys.front().resultIndex; + + for (ProducerKey key : keys) { + if (key.instance.op != sourceOp || key.resultIndex != resultIndex) + return materializedClass.op->emitError("packed run registration expects one producer result"); + if (key.instance.laneCount != 1) + return materializedClass.op->emitError("packed run registration expects one lane per packed fragment"); + } + + if 
(std::optional contiguousKey = getContiguousProducerRangeForKeys(keys)) { + state.availableValues.record(*contiguousKey, materializedClass.id, packed); + return success(); + } + + PackedScalarRunValue packedRun; + packedRun.targetClass = materializedClass.id; + packedRun.sourceOp = sourceOp; + packedRun.resultIndex = resultIndex; + packedRun.packed = packed; + packedRun.kind = PackedScalarRunKind::Materialized; + packedRun.fragmentType = rankedFragmentType; + + packedRun.slots.reserve(keys.size()); + for (ProducerKey key : keys) { + PackedScalarRunSlot slot; + slot.keys.push_back(key); + packedRun.slots.push_back(std::move(slot)); + } + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult emitPackedRunFanout(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef destinationClasses, + ArrayRef keys, + Value packed, + Type fragmentType, + Location loc) { + assert(!sourceClass.isBatch && "packed run fanout expects a scalar source class"); + + auto fanoutPlan = buildScalarSourceFanoutPlan(state, sourceClass, keys, destinationClasses, packed); + if (failed(fanoutPlan)) + return failure(); + if (failed(emitScalarSourceFanoutSends(state, sourceClass, packed, *fanoutPlan, loc))) + return failure(); + + for (const ScalarSourceReceivePlan& plan : fanoutPlan->receivePlans) { + MaterializedClass& targetClass = state.classes[plan.targetClass]; + + Value received = appendReceive(state, targetClass, plan.receiveType, plan.messages, loc); + + if (plan.projectedExtractOp) { + state.projectedExtractReplacements[plan.projectedExtractOp][plan.targetClass] = + ProjectedExtractReplacement {received, plan.projectedLayout}; + continue; + } + + if (failed(registerPackedRunValue(state, targetClass, keys, received, fragmentType, loc))) + return failure(); + } + + return success(); +} + +FailureOr> cloneBatchBodyForLane(MaterializerState& state, + MaterializedClass& targetClass, + const ComputeInstance& instance, + Value 
laneValue, + ArrayRef resultIndices, + CloneIndexingContext indexing) { + auto batch = dyn_cast(instance.op); + if (!batch) + return failure(); + + IRMapping mapper; + auto sourceLaneArg = batch.getLaneArgument(); + if (!sourceLaneArg) + return batch.emitOpError("expected source compute_batch lane block argument"); + + mapper.map(*sourceLaneArg, laneValue); + + OpBuilder::InsertPoint cloneInsertionPoint = state.rewriter.saveInsertionPoint(); + + mapWeights(state, targetClass, instance, mapper); + if (failed(mapInputs(state, targetClass, instance, mapper, indexing))) + return failure(); + + state.rewriter.restoreInsertionPoint(cloneInsertionPoint); + if (failed(cloneComputeTemplateBody(state, targetClass, instance, mapper, indexing))) + return failure(); + + SmallVector allOutputs = collectMappedBatchOutputs(batch, mapper); + if (allOutputs.empty() && !resultIndices.empty()) + return batch.emitOpError("failed to recover source compute_batch outputs"); + + SmallVector selectedOutputs; + selectedOutputs.reserve(resultIndices.size()); + for (size_t resultIndex : resultIndices) { + if (resultIndex >= allOutputs.size() || !allOutputs[resultIndex]) + return batch.emitOpError("failed to recover selected compute_batch output"); + selectedOutputs.push_back(allOutputs[resultIndex]); + } + + return selectedOutputs; +} + +FailureOr> materializeBatchOutputGroupLoop(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + const OutputDestinationGroup& group) { + assert(!run.empty() && "expected non-empty batch run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + + Operation* sourceOp = run.front().peers.front().op; + Location loc = sourceOp->getLoc(); + + if (run.size() == 1) { + if (run.front().peers.size() != 1) + return sourceOp->emitError("scalar batch output loop expects exactly one peer in singleton slot"); + + const ComputeInstance& item = run.front().peers.front(); + + 
state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value laneValue = getOrCreateIndexConstant(state.constantFolder, targetClass.op, item.laneStart); + return cloneBatchBodyForLane(state, targetClass, item, laneValue, group.resultIndices, {}); + } + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + auto sourceBatch = cast(sourceOp); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + SmallVector initValues; + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); + + Type fragmentType = fragmentTypes[resultIndex]; + FailureOr packedType = getPackedRunTensorType(fragmentType, run.size()); + if (failed(packedType)) + return sourceBatch.emitOpError("cannot materialize packed batch run for non-static ranked output"); + + initValues.push_back( + tensor::EmptyOp::create(state.rewriter, loc, packedType->getShape(), packedType->getElementType()).getResult()); + } + + SmallVector logicalLanes; + logicalLanes.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + if (slot.peers.size() != 1) + return sourceOp->emitError("scalar batch output loop expects exactly one peer per materialization slot"); + + const ComputeInstance& item = slot.peers.front(); + if (item.op != sourceOp) + return sourceOp->emitError("materialization run contains different source operations"); + + logicalLanes.push_back(item.laneStart); + } + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + 
loc, + lowerBound, + upperBound, + step, + ValueRange(initValues), + [&](OpBuilder&, Location, Value loopIndex, ValueRange iterArgs, SmallVectorImpl& yielded) { + Value sourceLane = createIndexedIndexValue(state, targetClass.op, logicalLanes, loopIndex, loc); + + FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + run.front().peers.front(), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = loopIndex, .projectionSlotIndex = loopIndex}); + if (failed(produced)) + return failure(); + + yielded.reserve(produced->size()); + for (auto [outputIndex, output] : llvm::enumerate(*produced)) { + auto fragmentType = cast(output.getType()); + Value acc = iterArgs[outputIndex]; + Value firstOffset = scaleIndexByDim0Size(state, targetClass.op, loopIndex, fragmentType.getDimSize(0), loc); + yielded.push_back(createDim0InsertSlice(state, loc, output, acc, firstOffset)); + } + return success(); + }); + if (failed(loop)) + return failure(); + + SmallVector results; + results.reserve(loop->results.size()); + for (Value result : loop->results) + results.push_back(result); + return results; +} + +SmallVector getMaterializationRunSlotOutputKeys(const MaterializationRunSlot& slot, + size_t resultIndex) { + SmallVector keys; + keys.reserve(slot.peers.size()); + for (const ComputeInstance& peer : slot.peers) + keys.push_back({peer, resultIndex}); + return keys; +} + +FailureOr> +getMaterializationRunSlotPeers(MaterializerState& state, MaterializedClass& targetClass, SlotId logicalSlot) { + if (targetClass.isBatch) + return getPeerLogicalInstances(state, targetClass, logicalSlot); + + auto streamIt = state.logicalInstancesByCpu.find(targetClass.cpus.front()); + if (streamIt == state.logicalInstancesByCpu.end() || logicalSlot >= streamIt->second.size()) + return failure(); + + return SmallVector {streamIt->second[logicalSlot]}; +} + +FailureOr collectBatchMaterializationRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId 
startSlot, + Operation* sourceOp) { + MaterializationRun run; + + for (SlotId slot = startSlot;; ++slot) { + ClassSlotKey classSlot {targetClass.id, slot}; + if (state.materializedLogicalSlots.contains(classSlot)) + break; + + FailureOr> peers = getMaterializationRunSlotPeers(state, targetClass, slot); + if (failed(peers) || peers->empty()) + break; + + bool validSlot = true; + for (const ComputeInstance& peer : *peers) { + if (peer.op != sourceOp || !isa(peer.op)) { + validSlot = false; + break; + } + } + + if (!validSlot) + break; + + MaterializationRunSlot runSlot; + runSlot.peers = std::move(*peers); + run.push_back(std::move(runSlot)); + } + + if (run.empty()) + return failure(); + + return run; +} + +SmallVector getMaterializationRunOutputKeys(ArrayRef run, size_t resultIndex) { + SmallVector keys; + for (const MaterializationRunSlot& slot : run) + llvm::append_range(keys, getMaterializationRunSlotOutputKeys(slot, resultIndex)); + return keys; +} + +ArrayRef getFirstMaterializationRunOriginalOutputs(MaterializerState& state, + ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + return getComputeInstanceOutputValuesCached(state, run.front().peers.front()); +} + +Operation* getMaterializationRunSourceOp(ArrayRef run) { + assert(!run.empty() && "expected non-empty materialization run"); + assert(!run.front().peers.empty() && "expected non-empty materialization run slot"); + return run.front().peers.front().op; +} + +Location getMaterializationRunLoc(ArrayRef run) { + return getMaterializationRunSourceOp(run)->getLoc(); +} + +bool hasMaterializationRunResultLiveExternalUse(MaterializerState& state, + ArrayRef run, + size_t resultIndex) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) { + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, peer); + if (resultIndex >= outputs.size()) 
+ return true; + + if (hasLiveExternalUseCached(state, outputs[resultIndex])) + return true; + } + } + + return false; +} + +bool hasMaterializationRunGroupLiveExternalUse(MaterializerState& state, + ArrayRef run, + const OutputDestinationGroup& group) { + for (size_t resultIndex : group.resultIndices) + if (hasMaterializationRunResultLiveExternalUse(state, run, resultIndex)) + return true; + + return false; +} + +bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId); + +bool hasMaterializationRunGroupSameClassConsumer(MaterializerState& state, + ClassId classId, + ArrayRef run, + const OutputDestinationGroup& group) { + for (size_t resultIndex : group.resultIndices) { + for (const MaterializationRunSlot& slot : run) { + for (const ComputeInstance& peer : slot.peers) + if (hasSameClassConsumer(state, {peer, resultIndex}, classId)) + return true; + } + } + + return false; +} + +void markMaterializationRunSlots(MaterializerState& state, + ClassId classId, + SlotId startSlot, + ArrayRef run) { + for (auto slotIndex : llvm::seq(0, run.size())) + state.materializedLogicalSlots.insert({classId, startSlot + static_cast(slotIndex)}); +} + +LogicalResult materializeScalarBatchRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + ArrayRef run) { + assert(!targetClass.isBatch && "scalar batch run materialization expects scalar target class"); + assert(!run.empty() && "expected non-empty batch run"); + + markMaterializationRunSlots(state, targetClass.id, startSlot, run); + + SmallVector groups = groupBatchRunOutputsByDestination(state, run); + ArrayRef firstOriginalOutputs = getFirstMaterializationRunOriginalOutputs(state, run); + + auto sourceBatch = cast(getMaterializationRunSourceOp(run)); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + Location loc = getMaterializationRunLoc(run); + + for (const OutputDestinationGroup& group : groups) { + if (run.size() > 1 && 
group.destinationClasses.empty() + && !hasMaterializationRunGroupLiveExternalUse(state, run, group) + && !hasMaterializationRunGroupSameClassConsumer(state, targetClass.id, run, group)) { + for (size_t resultIndex : group.resultIndices) { + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for deferred local packed run"); + + SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); + if (failed(registerDeferredLocalPackedRunValue(state, targetClass, keys, fragmentTypes[resultIndex], loc))) + return failure(); + } + + continue; + } + + FailureOr> packedOutputs = materializeBatchOutputGroupLoop(state, targetClass, run, group); + if (failed(packedOutputs)) + return failure(); + + for (auto [groupOutputIndex, resultIndex] : llvm::enumerate(group.resultIndices)) { + Value packed = (*packedOutputs)[groupOutputIndex]; + if (resultIndex >= fragmentTypes.size() || !fragmentTypes[resultIndex]) + return sourceBatch.emitOpError("failed to recover per-lane output type for packed batch run"); + + Type fragmentType = fragmentTypes[resultIndex]; + SmallVector keys = getMaterializationRunOutputKeys(run, resultIndex); + + auto rankedFragmentType = cast(fragmentType); + Value representativeOriginalOutput = firstOriginalOutputs[resultIndex]; + FailureOr recordedProjectedHostFragments = recordProjectedScalarHostFragmentsFromPackedRun( + state, targetClass, sourceBatch, resultIndex, run, packed, rankedFragmentType, representativeOriginalOutput, loc); + if (failed(recordedProjectedHostFragments)) + return failure(); + + if (run.size() == 1) { + if (*recordedProjectedHostFragments) { + if (failed(emitScalarSourceCommunication(state, targetClass, keys, packed, loc))) + return failure(); + continue; + } + + if (failed(emitOutputFanout(state, targetClass, keys, packed, representativeOriginalOutput, loc))) + return failure(); + continue; + } + + if (failed(emitPackedRunFanout(state, 
targetClass, group.destinationClasses, keys, packed, fragmentType, loc))) + return failure(); + + if (failed(registerPackedRunValue(state, targetClass, keys, packed, fragmentType, loc))) + return failure(); + + if (*recordedProjectedHostFragments) + continue; + + for (auto [runIndex, slot] : llvm::enumerate(run)) { + assert(slot.peers.size() == 1 && "scalar materialization run slot must contain exactly one peer"); + + ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, slot.peers.front()); + Value originalOutput = originalOutputs[resultIndex]; + + if (!hasLiveExternalUseCached(state, originalOutput)) + continue; + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value fragment = getPackedSliceForRunIndex(state, targetClass.op, packed, rankedFragmentType, runIndex, loc); + + if (failed(emitHostCommunication(state, targetClass, fragment, originalOutput))) + return failure(); + } + } + } + + return success(); +} + +bool hasSameClassConsumer(MaterializerState& state, ProducerKey producerKey, ClassId classId) { + SameClassConsumerLookupKey lookupKey{producerKey.instance.op, producerKey.resultIndex, classId}; + auto it = state.sameClassConsumerIndex.find(lookupKey); + if (it == state.sameClassConsumerIndex.end()) + return false; + + for (ProducerKey existing : it->second) + if (containsProducerKey(existing, producerKey) || containsProducerKey(producerKey, existing)) + return true; + return false; +} + +bool canCompactBatchClassRun(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run) { + if (run.size() < 2) + return false; + if (run.front().peers.empty()) + return false; + + ArrayRef outputs = getComputeInstanceOutputValuesCached(state, run.front().peers.front()); + + for (auto [resultIndex, ignored] : llvm::enumerate(outputs)) { + (void) ignored; + for (const MaterializationRunSlot& slot : run) { + if (slot.peers.empty()) + return false; + + for (const ComputeInstance& peer : slot.peers) { + ArrayRef 
peerOutputs = getComputeInstanceOutputValuesCached(state, peer); + if (resultIndex >= peerOutputs.size()) + return false; + + Value originalOutput = peerOutputs[resultIndex]; + if (hasLiveExternalUseCached(state, originalOutput)) + return false; + + ProducerKey key {peer, resultIndex}; + if (hasSameClassConsumer(state, key, targetClass.id)) + return false; + } + } + } + + return true; +} + +Value createBatchRunFlatIndex(MaterializerState& state, MaterializedClass& targetClass, Value slotIndex, Location loc) { + auto batch = cast(targetClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected materialized compute_batch lane argument"); + + MLIRContext* context = state.func.getContext(); + AffineExpr d0 = getAffineDimExpr(0, context); + AffineExpr d1 = getAffineDimExpr(1, context); + + int64_t laneCount = static_cast(targetClass.cpus.size()); + AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 * laneCount + d1); + return createOrFoldAffineApply(state.rewriter, loc, map, ValueRange {slotIndex, *laneArg}, state.func); +} + +Value createBatchClassRunSourceLane(MaterializerState& state, + MaterializedClass& targetClass, + ArrayRef run, + Value slotIndex, + Location loc) { + SmallVector sourceLanes; + sourceLanes.reserve(run.size() * targetClass.cpus.size()); + + for (auto [runSlotIndex, slot] : llvm::enumerate(run)) { + (void) runSlotIndex; + assert(slot.peers.size() == targetClass.cpus.size() && "expected one peer per materialized batch lane"); + for (const ComputeInstance& peer : slot.peers) + sourceLanes.push_back(peer.laneStart); + } + + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + return createIndexedIndexValue(state, + targetClass.op, + sourceLanes, + flatIndex, + loc, + static_cast(targetClass.cpus.size()), + /*allowExhaustiveTiledSearch=*/false); +} + +LogicalResult buildBatchRunSendPlans(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const 
OutputDestinationGroup& group, + SmallVectorImpl& plans) { + assert(sourceClass.isBatch && "batch run send planning expects a materialized batch source"); + + for (size_t resultIndex : group.resultIndices) { + for (ClassId destinationClass : group.destinationClasses) { + if (destinationClass == sourceClass.id) + return sourceClass.op->emitError("batch-target run compaction cannot handle same-class consumers"); + + MaterializedClass& targetClass = state.classes[destinationClass]; + + if (targetClass.isBatch && targetClass.cpus.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError( + "cannot compact batch run communication between batch classes of different sizes"); + + BatchRunSendPlan plan; + plan.resultIndex = resultIndex; + plan.destinationClass = destinationClass; + + size_t messageCount = run.size() * sourceClass.cpus.size(); + plan.messages.channelIds.reserve(messageCount); + plan.messages.sourceCoreIds.reserve(messageCount); + plan.messages.targetCoreIds.reserve(messageCount); + + for (size_t slotIndex = 0; slotIndex < run.size(); ++slotIndex) { + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, sourceCpu, "batch run source core id"); + if (failed(checkedSourceCpu)) + return failure(); + auto checkedTargetCpu = + getCheckedCoreId(targetClass.op, + targetClass.isBatch ? 
targetClass.cpus[lane] : targetClass.cpus.front(), + "batch run target core id"); + if (failed(checkedTargetCpu)) + return failure(); + plan.messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + } + (void) slotIndex; + } + + plans.push_back(std::move(plan)); + } + } + + return success(); +} + +void appendBatchRunSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + const BatchRunSendPlan& plan, + Value flatIndex, + Location loc) { + assert(sourceClass.isBatch && "batch run send expects a materialized batch source"); + + std::optional preferredPeriod = static_cast(sourceClass.cpus.size()); + Value channelId = createIndexedChannelId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + Value sourceCoreId = createIndexedSourceCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + Value targetCoreId = createIndexedTargetCoreId(state, sourceClass.op, plan.messages, flatIndex, loc, preferredPeriod); + + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); +} + +LogicalResult appendPackedScalarRunReceives(MaterializerState& state, + MaterializedClass& sourceClass, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType, + Location loc) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + assert(!targetClass.isBatch && "packed scalar run receives expect a scalar target class"); + + size_t laneCount = sourceClass.cpus.size(); + size_t receiveCount = run.size() * laneCount; + + if (failed(plan.messages.verify(targetClass.op))) + return failure(); + + if (receiveCount != plan.messages.size()) + return targetClass.op->emitError("inconsistent flattened batch run receive plan"); + + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("packed scalar run receive 
expects static ranked fragment type"); + + PackedScalarRunValue packedRun; + packedRun.targetClass = targetClass.id; + packedRun.sourceOp = run.front().peers.front().op; + packedRun.resultIndex = plan.resultIndex; + packedRun.kind = PackedScalarRunKind::DeferredReceive; + packedRun.fragmentType = rankedFragmentType; + + packedRun.messages = plan.messages; + + packedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot packedSlot; + packedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + packedRun.slots.push_back(std::move(packedSlot)); + } + + if (failed(validatePackedScalarRunMetadata(targetClass.op, packedRun))) + return failure(); + + state.availableValues.recordPackedRun(std::move(packedRun)); + return success(); +} + +LogicalResult recordIndexedBatchRunReceives(MaterializerState& state, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + auto rankedFragmentType = dyn_cast(fragmentType); + if (!rankedFragmentType || !rankedFragmentType.hasStaticShape() || rankedFragmentType.getRank() == 0) + return targetClass.op->emitError("indexed batch run receive expects static ranked fragment type"); + + IndexedBatchRunValue indexedRun; + indexedRun.targetClass = targetClass.id; + indexedRun.sourceOp = run.front().peers.front().op; + indexedRun.resultIndex = plan.resultIndex; + indexedRun.fragmentType = rankedFragmentType; + indexedRun.messages = plan.messages; + indexedRun.slots.reserve(run.size()); + for (const MaterializationRunSlot& slot : run) { + PackedScalarRunSlot indexedSlot; + indexedSlot.keys = getMaterializationRunSlotOutputKeys(slot, plan.resultIndex); + indexedRun.slots.push_back(std::move(indexedSlot)); + } + + state.availableValues.recordIndexedBatchRun(std::move(indexedRun)); + return success(); +} + +LogicalResult appendBatchRunReceives(MaterializerState& state, + MaterializedClass& 
sourceClass, + ArrayRef run, + const BatchRunSendPlan& plan, + Type fragmentType, + Location loc) { + MaterializedClass& targetClass = state.classes[plan.destinationClass]; + + if (!targetClass.isBatch) + return appendPackedScalarRunReceives(state, sourceClass, run, plan, fragmentType, loc); + return recordIndexedBatchRunReceives(state, run, plan, fragmentType); +} + +LogicalResult materializeBatchClassRun(MaterializerState& state, + MaterializedClass& targetClass, + SlotId startSlot, + ArrayRef run) { + assert(targetClass.isBatch && "batch-target run materialization expects a materialized batch class"); + assert(!run.empty() && "expected non-empty batch-target run"); + + if (!canCompactBatchClassRun(state, targetClass, run)) + return failure(); + + markMaterializationRunSlots(state, targetClass.id, startSlot, run); + + SmallVector groups = groupBatchRunOutputsByDestination(state, run); + + auto sourceBatch = cast(run.front().peers.front().op); + SmallVector& fragmentTypes = getBatchOutputFragmentTypesCached(state, sourceBatch); + Location loc = sourceBatch.getLoc(); + + for (const OutputDestinationGroup& group : groups) { + SmallVector sendPlans; + if (failed(buildBatchRunSendPlans(state, targetClass, run, group, sendPlans))) + return failure(); + + Value lowerBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 0); + Value upperBound = getOrCreateIndexConstant(state.constantFolder, targetClass.op, static_cast(run.size())); + Value step = getOrCreateIndexConstant(state.constantFolder, targetClass.op, 1); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + auto loop = buildNormalizedScfFor( + state.rewriter, + loc, + lowerBound, + upperBound, + step, + ValueRange {}, + [&](OpBuilder&, Location, Value slotIndex, ValueRange, SmallVectorImpl&) { + Value sourceLane = createBatchClassRunSourceLane(state, targetClass, run, slotIndex, loc); + Value flatIndex = createBatchRunFlatIndex(state, targetClass, slotIndex, loc); + + 
FailureOr> produced = + cloneBatchBodyForLane(state, + targetClass, + getScheduledChunkForLogicalInstance(state, run.front().peers.front()), + sourceLane, + group.resultIndices, + CloneIndexingContext {.runSlotIndex = slotIndex, .projectionSlotIndex = slotIndex}); + if (failed(produced)) + return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + auto resultIt = llvm::find(group.resultIndices, plan.resultIndex); + if (resultIt == group.resultIndices.end()) + return failure(); + + size_t groupOutputIndex = static_cast(std::distance(group.resultIndices.begin(), resultIt)); + appendBatchRunSend(state, targetClass, (*produced)[groupOutputIndex], plan, flatIndex, loc); + } + return success(); + }); + if (failed(loop)) + return failure(); + + for (const BatchRunSendPlan& plan : sendPlans) { + if (plan.resultIndex >= fragmentTypes.size() || !fragmentTypes[plan.resultIndex]) + return failure(); + + if (failed(appendBatchRunReceives(state, targetClass, run, plan, fragmentTypes[plan.resultIndex], loc))) + return failure(); + } + } + + return success(); +} + +LogicalResult materializeInstanceSlot(MaterializerState& state, + const ComputeInstance& instance) { + auto cpuIt = state.schedule.computeToCpuMap.find(instance); + if (cpuIt == state.schedule.computeToCpuMap.end()) + return instance.op->emitError("schedule materialization expected a CPU assignment for every compute instance"); + auto logicalRangeIt = state.scheduledInstanceToLogicalSlots.find(instance); + if (logicalRangeIt == state.scheduledInstanceToLogicalSlots.end()) + return instance.op->emitError("schedule materialization expected logical slots for every compute instance"); + + ClassId classId = state.cpuToClass.lookup(cpuIt->second); + MaterializedClass& targetClass = state.classes[classId]; + + LogicalSlotRange logicalRange = logicalRangeIt->second; + SlotId startLogicalSlot = logicalRange.start; + while (startLogicalSlot < logicalRange.start + logicalRange.count + && 
state.materializedLogicalSlots.contains({classId, startLogicalSlot})) { + ++startLogicalSlot; + } + if (startLogicalSlot == logicalRange.start + logicalRange.count) + return success(); + + if (isa(instance.op)) { + FailureOr run = collectBatchMaterializationRun(state, targetClass, startLogicalSlot, instance.op); + + if (succeeded(run)) { + if (!targetClass.isBatch) + return materializeScalarBatchRun(state, targetClass, startLogicalSlot, *run); + + if (succeeded(materializeBatchClassRun(state, targetClass, startLogicalSlot, *run))) + return success(); + } + } + + if (!state.materializedLogicalSlots.insert({classId, startLogicalSlot}).second) + return success(); + + FailureOr> peers = + getMaterializationRunSlotPeers(state, targetClass, startLogicalSlot); + if (failed(peers)) + return instance.op->emitError("failed to collect peer compute instances for equivalence class logical slot"); + + Value projectionSlotIndex = getOrCreateIndexConstant( + state.constantFolder, targetClass.op, static_cast(startLogicalSlot - logicalRange.start)); + FailureOr> materializedOutputs = + cloneInstanceBody(state, + targetClass, + *peers, + CloneIndexingContext {.runSlotIndex = std::nullopt, .projectionSlotIndex = projectionSlotIndex}); + if (failed(materializedOutputs)) + return failure(); + + ArrayRef originalOutputs = getComputeInstanceOutputValuesCached(state, instance); + if (materializedOutputs->size() != originalOutputs.size()) + return instance.op->emitError("materialized output count does not match original compute instance output count"); + + for (auto [resultIndex, zipped] : llvm::enumerate(llvm::zip(*materializedOutputs, originalOutputs))) { + Value materializedOutput = std::get<0>(zipped); + Value originalOutput = std::get<1>(zipped); + MaterializationRunSlot slot; + slot.peers = *peers; + SmallVector keys = getMaterializationRunSlotOutputKeys(slot, resultIndex); + if (failed(emitOutputFanout(state, targetClass, keys, materializedOutput, originalOutput, 
instance.op->getLoc()))) + return failure(); + } + + return success(); +} + +FailureOr createReceiveConcatLoop(MaterializerState& state, + MaterializedClass& targetClass, + RankedTensorType concatType, + RankedTensorType fragmentType, + const MessageVector& messages, + Location loc) { + assert(succeeded(messages.verify(targetClass.op)) && "message metadata is inconsistent"); + assert(!messages.empty() && "expected at least one receive"); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value init = + tensor::EmptyOp::create(state.rewriter, loc, concatType.getShape(), concatType.getElementType()).getResult(); + return emitIndexedFragmentInsertLoop( + state, + targetClass, + init, + static_cast(messages.size()), + [&](Value index) -> FailureOr { + Value channelId = createIndexedChannelId(state, targetClass.op, messages, index, loc); + Value sourceCoreId = createIndexedSourceCoreId(state, targetClass.op, messages, index, loc); + Value targetCoreId = createIndexedTargetCoreId(state, targetClass.op, messages, index, loc); + return SpatChannelReceiveOp::create(state.rewriter, loc, fragmentType, channelId, sourceCoreId, targetCoreId) + .getOutput(); + }, + [&](Value index) -> FailureOr { + return scaleIndexByDim0SizeInClass(state, targetClass, index, fragmentType.getDimSize(0), loc); + }, + loc); +} + +bool valueMayEvaluateToCore(Value value, int64_t coreId) { + if (std::optional constant = getConstantIndexValue(value)) + return *constant == coreId; + + auto affineApply = value.getDefiningOp(); + if (!affineApply) + return false; + + AffineMap map = affineApply.getAffineMap(); + if (map.getNumResults() != 1 || map.getNumDims() != 1 || map.getNumSymbols() != 0 + || affineApply.getMapOperands().size() != 1) + return false; + + auto iv = dyn_cast(affineApply.getMapOperands().front()); + if (!iv) + return false; + + auto loop = dyn_cast_or_null(iv.getOwner()->getParentOp()); + if (!loop || loop.getInductionVar() != iv) + return false; + + 
std::optional lower = getConstantIndexValue(loop.getLowerBound()); + std::optional upper = getConstantIndexValue(loop.getUpperBound()); + std::optional step = getConstantIndexValue(loop.getStep()); + if (!lower || !upper || !step || *step <= 0) + return false; + + for (int64_t iteration = *lower; iteration < *upper; iteration += *step) { + FailureOr evaluated = evaluateSingleResultAffineMap(map, ArrayRef{iteration}); + if (succeeded(evaluated) && *evaluated == coreId) + return true; + } + + return false; +} + +bool operationContainsReceiveFromPeer(Operation& op, int64_t localCore, int64_t peerCore, Type payloadType) { + bool found = false; + op.walk([&](SpatChannelReceiveOp receive) { + if (receive.getOutput().getType() != payloadType) + return; + if (!valueMayEvaluateToCore(receive.getTargetCoreId(), localCore)) + return; + if (!valueMayEvaluateToCore(receive.getSourceCoreId(), peerCore)) + return; + found = true; + }); + return found; +} + +LogicalResult orderLowerCoreScalarSendsAfterMatchingReceives(MaterializerState& state) { + for (MaterializedClass& materializedClass : state.classes) { + if (materializedClass.isBatch || materializedClass.cpus.empty()) + continue; + + int64_t localCore = static_cast(materializedClass.cpus.front()); + Block* body = materializedClass.body; + if (!body) + continue; + + bool changed = true; + while (changed) { + changed = false; + for (Operation& op : llvm::make_early_inc_range(*body)) { + if (&op == body->getTerminator()) + break; + + auto send = dyn_cast(&op); + if (!send) + continue; + + std::optional sourceCore = getConstantIndexValue(send.getSourceCoreId()); + std::optional targetCore = getConstantIndexValue(send.getTargetCoreId()); + if (!sourceCore || !targetCore || *sourceCore != localCore || *sourceCore >= *targetCore) + continue; + + Operation* matchingReceiveContainer = nullptr; + for (Operation* candidate = op.getNextNode(); candidate && candidate != body->getTerminator(); + candidate = candidate->getNextNode()) { + if 
(operationContainsReceiveFromPeer(*candidate, localCore, *targetCore, send.getInput().getType())) { + matchingReceiveContainer = candidate; + break; + } + } + + if (!matchingReceiveContainer) + continue; + + op.moveAfter(matchingReceiveContainer); + changed = true; + break; + } + } + } + + return success(); +} + +void replaceHostUses(MaterializerState& state) { + for (const auto& [oldValue, replacement] : state.hostReplacements) + replaceLiveExternalUses(oldValue, replacement, state.oldComputeOps); +} + +LogicalResult eraseOldComputeOps(MaterializerState& state) { + DenseSet seen; + for (const ComputeInstance& instance : state.schedule.dominanceOrderCompute) { + if (!seen.insert(instance.op).second) + continue; + instance.op->dropAllUses(); + instance.op->erase(); + } + return success(); +} + +} // namespace + +LogicalResult +MergeScheduleMaterializer::run(func::FuncOp func, const MergeScheduleResult& schedule, int64_t& nextChannelId) { + if (schedule.dominanceOrderCompute.empty()) + return success(); + + MaterializerState state(func, schedule, nextChannelId); + if (failed(buildMaterializationWorkStreams(state))) + return failure(); + if (failed(buildMaterializationClassesFromScheduleEquivalence(state))) + return failure(); + if (failed(verifyScheduleEquivalenceMatchesLogicalStreams(state))) + return failure(); + if (state.classes.empty()) + return success(); + + if (failed(collectHostOutputs(state))) + return failure(); + if (failed(createEmptyMaterializedOps(state))) + return failure(); + if (failed(collectProducerDestinations(state))) + return failure(); + if (failed(collectProjectedTransfers(state))) + return failure(); + + for (const ComputeInstance& instance : schedule.dominanceOrderCompute) + if (failed(materializeInstanceSlot(state, instance))) + return failure(); + + if (failed(finalizeProjectedHostOutputFragments(state))) + return failure(); + if (failed(orderLowerCoreScalarSendsAfterMatchingReceives(state))) + return failure(); + + for 
(MaterializedClass& materializedClass : state.classes) + if (failed(localizeAllScheduledBodyCaptures(state, materializedClass))) + return failure(); + + replaceHostUses(state); + if (failed(eraseOldComputeOps(state))) + return failure(); + + LogicalResult _ = runRegionDCE(state.rewriter, state.func.getBody()); + (void) _; + + return success(); +} + +} // namespace spatial +} // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej new file mode 100644 index 0000000..3abd678 --- /dev/null +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp.rej @@ -0,0 +1,128 @@ +--- src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.043731129 +0000 ++++ src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp 2026-06-24 18:51:29.026726895 +0000 +@@ -4112,104 +4112,8 @@ + Value originalOutput, + Location loc); + +-FailureOr> rematerializeProjectionIndexListForBatchHostOutput( +- MaterializerState& state, +- MaterializedClass& sourceClass, +- ArrayRef values, +- IRMapping& mapper, +- Location loc) { +- SmallVector localized; +- localized.reserve(values.size()); +- for (OpFoldResult value : values) { +- FailureOr remapped = +- rematerializeIndexOpFoldResultInClass(state, sourceClass, value, loc, &mapper); +- if (failed(remapped)) +- return failure(); +- localized.push_back(*remapped); +- } +- return localized; +-} +- +-LogicalResult createProjectionAwareBatchHostInsert(MaterializerState& state, +- MaterializedClass& sourceClass, +- Value originalOutput, +- Value payload, +- Value destination, +- ArrayRef keys, +- Location loc) { +- auto originalResult = dyn_cast(originalOutput); +- if (!originalResult) +- return failure(); +- +- auto sourceBatch = dyn_cast_or_null(originalResult.getOwner()); +- if (!sourceBatch || 
sourceBatch.getNumResults() == 0) +- return failure(); +- +- FailureOr projection = +- getBatchResultProjectionInsert(sourceBatch, originalResult.getResultNumber()); +- if (failed(projection)) +- return failure(); +- +- auto sourceLaneArg = sourceBatch.getLaneArgument(); +- if (!sourceLaneArg) +- return failure(); +- +- auto materializedBatch = dyn_cast(sourceClass.op); +- if (!materializedBatch) +- return failure(); +- +- auto materializedLaneArg = materializedBatch.getLaneArgument(); +- if (!materializedLaneArg) +- return failure(); +- +- if (keys.size() != sourceClass.cpus.size()) +- return failure(); +- +- SmallVector logicalLanes; +- logicalLanes.reserve(keys.size()); +- for (ProducerKey key : keys) { +- if (key.instance.op != sourceBatch.getOperation() || key.resultIndex != originalResult.getResultNumber()) +- return failure(); +- logicalLanes.push_back(key.instance.laneStart); +- } +- +- IRMapping mapper; +- Value logicalLane = createIndexedIndexValue(state, +- sourceClass.op, +- ArrayRef(logicalLanes), +- *materializedLaneArg, +- loc, +- static_cast(sourceClass.cpus.size()), +- /*allowExhaustiveTiledSearch=*/false); +- mapper.map(*sourceLaneArg, logicalLane); +- +- FailureOr> offsets = +- rematerializeProjectionIndexListForBatchHostOutput( +- state, sourceClass, projection->getMixedOffsets(), mapper, loc); +- if (failed(offsets)) +- return failure(); +- FailureOr> sizes = +- rematerializeProjectionIndexListForBatchHostOutput( +- state, sourceClass, projection->getMixedSizes(), mapper, loc); +- if (failed(sizes)) +- return failure(); +- FailureOr> strides = +- rematerializeProjectionIndexListForBatchHostOutput( +- state, sourceClass, projection->getMixedStrides(), mapper, loc); +- if (failed(strides)) +- return failure(); +- +- tensor::ParallelInsertSliceOp::create( +- state.rewriter, loc, payload, destination, *offsets, *sizes, *strides); +- return success(); +-} +- + LogicalResult +-setHostOutputValue(MaterializerState& state, +- MaterializedClass& 
sourceClass, +- Value originalOutput, +- Value payload, +- ArrayRef keys = {}) { ++setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Value originalOutput, Value payload) { + auto resultIt = sourceClass.hostOutputToResultIndex.find(originalOutput); + if (resultIt == sourceClass.hostOutputToResultIndex.end()) + return sourceClass.op->emitError("missing host result slot for materialized output") +@@ -4253,10 +4157,6 @@ + return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); + + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); +- if (succeeded(createProjectionAwareBatchHostInsert( +- state, sourceClass, originalOutput, payload, *outputArg, keys, payload.getLoc()))) +- return success(); +- + createDim0ParallelInsertSlice(state, payload.getLoc(), payload, *outputArg, *laneArg); + return success(); + } +@@ -4276,7 +4176,7 @@ + + MaterializedClass& ownerClass = state.classes[ownerIt->second]; + if (sourceClass.id == ownerClass.id) +- return setHostOutputValue(state, ownerClass, originalOutput, payload, keys); ++ return setHostOutputValue(state, ownerClass, originalOutput, payload); + + // Keep the old deadlock-free communication discipline: only scalar-to-scalar + // host-owner forwarding is introduced here. 
Batch host publication remains on diff --git a/validation/tools/classification_local_image_validation.py b/validation/tools/classification_local_image_validation.py new file mode 100644 index 0000000..a440913 --- /dev/null +++ b/validation/tools/classification_local_image_validation.py @@ -0,0 +1,295 @@ +#!/usr/bin/env python3.13 + +import argparse +import math +import subprocess +import sys +from pathlib import Path + +import numpy as np +from PIL import Image, ImageDraw + +SCRIPT_DIR = Path(__file__).resolve().parent +VALIDATION_DIR = SCRIPT_DIR.parent +REPO_ROOT = VALIDATION_DIR.parent +if str(VALIDATION_DIR) not in sys.path: + sys.path.insert(0, str(VALIDATION_DIR)) + +from onnx_utils import _ONNX_TO_NP, onnx_io, write_inputs_to_memory_bin +from validate_one import ( + MODE_COMPILE_ONLY, + build_dump_ranges, + parse_pim_simulator_outputs, + run_pim_simulator, + sanitize_output_name, + validate_network, +) +from yolo_real_image_validation import save_tensor_csv + +IMAGENET_MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32) +IMAGENET_STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32) +DEFAULT_VGG_MODEL = VALIDATION_DIR / "networks" / "vgg16" / "depth_35" / "vgg16_depth_35.onnx" +DEFAULT_RESNET_MODEL = VALIDATION_DIR / "networks" / "resnet" / "resnet18_torchvision.onnx" + + +def resolve_default_paths(): + return { + "raptor_path": REPO_ROOT / "build_release" / "Release" / "bin" / "onnx-mlir", + "onnx_include_dir": REPO_ROOT / "onnx-mlir" / "include", + "simulator_dir": REPO_ROOT / "backend-simulators" / "pim" / "pim-simulator", + } + + +def resolve_model_path(network: str | None, model: Path | None) -> Path: + if model is not None: + return model.resolve() + if network == "resnet": + return DEFAULT_RESNET_MODEL.resolve() + if network == "vgg": + return DEFAULT_VGG_MODEL.resolve() + raise SystemExit("Pass --model or select a default with --network {resnet,vgg}.") + + +def ensure_local_artifacts(args, model_path: Path): + validate_network( + 
network_onnx_path=model_path, + raptor_path=args.raptor_path, + onnx_include_dir=args.onnx_include_dir, + simulator_dir=args.simulator_dir, + crossbar_size=args.crossbar_size, + crossbar_count=args.crossbar_count, + core_count=args.core_count, + command_timeout_seconds=args.command_timeout_seconds, + mode=MODE_COMPILE_ONLY, + verbose=args.verbose, + ) + + +def ensure_existing_artifacts(model_dir: Path): + required_paths = [ + model_dir / "runner" / "build" / "runner", + model_dir / "raptor" / "pim" / "config.json", + model_dir / "raptor" / "pim" / "memory.bin", + ] + missing = [str(path) for path in required_paths if not path.exists()] + if missing: + raise FileNotFoundError( + "Missing compiled local artifacts. Re-run without --skip-compile or restore these paths:\n " + + "\n ".join(missing) + ) + + +def preprocess_classification_image(image_path: Path) -> tuple[Image.Image, np.ndarray]: + image = Image.open(image_path).convert("RGB") + width, height = image.size + scale = 256.0 / min(width, height) + resized_size = ( + max(1, int(round(width * scale))), + max(1, int(round(height * scale))), + ) + resized = image.resize(resized_size, Image.Resampling.BILINEAR) + + left = (resized.width - 224) // 2 + top = (resized.height - 224) // 2 + cropped = resized.crop((left, top, left + 224, top + 224)) + + array = np.asarray(cropped, dtype=np.float32) / 255.0 + array = (array - IMAGENET_MEAN) / IMAGENET_STD + chw = np.transpose(array, (2, 0, 1)) + tensor = np.expand_dims(chw.astype(np.float32, copy=False), axis=0) + return image, tensor + + +def load_labels(labels_path: Path | None) -> list[str] | None: + if labels_path is None: + return None + labels = [line.strip() for line in labels_path.read_text().splitlines()] + return labels or None + + +def softmax(values: np.ndarray) -> np.ndarray: + shifted = values - np.max(values) + exp = np.exp(shifted) + denom = exp.sum() + if not math.isfinite(float(denom)) or denom <= 0.0: + raise RuntimeError("Softmax received non-finite 
output scores.") + return exp / denom + + +def decode_classification_output(output: np.ndarray, labels: list[str] | None, top_k: int): + scores = np.asarray(output, dtype=np.float64).reshape(-1) + probabilities = softmax(scores) + limit = min(top_k, probabilities.size) + top_indices = np.argsort(probabilities)[-limit:][::-1] + results = [] + for index in top_indices: + label = None + if labels is not None and 0 <= int(index) < len(labels): + label = labels[int(index)] + results.append( + { + "index": int(index), + "label": label, + "probability": float(probabilities[int(index)]), + } + ) + return results + + +def render_result_line(result) -> str: + name = result["label"] if result["label"] else f'class {result["index"]}' + return f'{name}: {result["probability"] * 100.0:.2f}%' + + +def draw_classification_panel(image: Image.Image, results, output_path: Path): + annotated = image.copy() + draw = ImageDraw.Draw(annotated) + lines = [render_result_line(result) for result in results] + if not lines: + lines = ["No predictions"] + + padding = 10 + line_gap = 4 + max_width = 0 + line_heights = [] + for line in lines: + left, top, right, bottom = draw.textbbox((0, 0), line) + max_width = max(max_width, right - left) + line_heights.append(bottom - top) + + panel_height = padding * 2 + sum(line_heights) + line_gap * (len(lines) - 1) + panel_width = padding * 2 + max_width + origin_x = 12 + origin_y = 12 + draw.rounded_rectangle( + (origin_x, origin_y, origin_x + panel_width, origin_y + panel_height), + radius=10, + fill=(0, 0, 0), + ) + + y = origin_y + padding + for line, line_height in zip(lines, line_heights): + draw.text((origin_x + padding, y), line, fill=(255, 255, 255)) + y += line_height + line_gap + + annotated.save(output_path) + + +def run_reference_and_simulator(args, model_path: Path, tensor: np.ndarray): + model_dir = model_path.parent + runner_build_dir = model_dir / "runner" / "build" + runner_path = runner_build_dir / "runner" + pim_dir = model_dir / 
"raptor" / "pim" + simulation_dir = model_dir / "classification_demo" / "simulation" + reference_dir = model_dir / "classification_demo" / "reference" + inputs_dir = model_dir / "classification_demo" / "inputs" + + simulation_dir.mkdir(parents=True, exist_ok=True) + reference_dir.mkdir(parents=True, exist_ok=True) + inputs_dir.mkdir(parents=True, exist_ok=True) + + input_descriptors, output_descriptors = onnx_io(model_path) + if len(input_descriptors) != 1: + raise RuntimeError(f"Expected one classification input tensor, found {len(input_descriptors)}") + if len(output_descriptors) != 1: + raise RuntimeError(f"Expected one classification output tensor, found {len(output_descriptors)}") + + input_index, _input_name, _input_dtype, input_shape = input_descriptors[0] + if list(tensor.shape) != list(input_shape): + raise RuntimeError(f"Preprocessed tensor shape {list(tensor.shape)} does not match model input {input_shape}") + + input_csv = inputs_dir / "in0.csv" + save_tensor_csv(tensor, input_csv) + + runner_cmd = [ + str(runner_path), + f"--in{input_index}-csv-file", + str(input_csv), + f"--in{input_index}-shape", + "x".join(str(dim) for dim in tensor.shape), + "--save-csv-dir", + str(reference_dir), + ] + subprocess.run(runner_cmd, cwd=runner_build_dir, check=True) + + write_inputs_to_memory_bin(pim_dir / "memory.bin", pim_dir / "config.json", [tensor]) + dump_ranges = build_dump_ranges(pim_dir / "config.json", output_descriptors) + output_bin_path = simulation_dir / "out.bin" + run_pim_simulator( + args.simulator_dir, + pim_dir, + output_bin_path, + dump_ranges, + timeout_sec=args.command_timeout_seconds, + ) + + output_index, output_name, output_dtype_code, output_shape = output_descriptors[0] + output_dtype = np.dtype(_ONNX_TO_NP[output_dtype_code]) + reference_csv = reference_dir / f"output{output_index}_{sanitize_output_name(output_name)}.csv" + reference_output = np.loadtxt(reference_csv, delimiter=",", dtype=output_dtype).reshape(output_shape) + 
simulator_output = parse_pim_simulator_outputs(output_bin_path, output_descriptors)[0] + return reference_output, simulator_output + + +def print_topk(title: str, results): + print(title) + for rank, result in enumerate(results, start=1): + label_text = result["label"] if result["label"] else f'class {result["index"]}' + print(f' {rank}. {label_text} ({result["probability"] * 100.0:.2f}%) [index={result["index"]}]') + + +def main(): + defaults = resolve_default_paths() + + parser = argparse.ArgumentParser(description="Run a VGG or ResNet ONNX model through the Raptor simulator and annotate the image with top classification results.") + parser.add_argument("--model", type=Path, default=None) + parser.add_argument("--network", choices=("resnet", "vgg"), default=None) + parser.add_argument("--image", type=Path, required=True) + parser.add_argument("--labels", type=Path, default=None) + parser.add_argument("--output", type=Path, required=True) + parser.add_argument("--raptor-path", type=Path, default=defaults["raptor_path"]) + parser.add_argument("--onnx-include-dir", type=Path, default=defaults["onnx_include_dir"]) + parser.add_argument("--simulator-dir", type=Path, default=defaults["simulator_dir"]) + parser.add_argument("--crossbar-size", type=int, default=2048) + parser.add_argument("--crossbar-count", type=int, default=256) + parser.add_argument("--core-count", type=int, default=1000) + parser.add_argument("--top-k", type=int, default=5) + parser.add_argument("--command-timeout-seconds", type=float, default=7200.0) + parser.add_argument("--skip-compile", action="store_true") + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + args.model = resolve_model_path(args.network, args.model) + args.image = args.image.resolve() + args.output = args.output.resolve() + args.labels = args.labels.resolve() if args.labels else None + args.raptor_path = args.raptor_path.resolve() + args.onnx_include_dir = args.onnx_include_dir.resolve() + 
args.simulator_dir = args.simulator_dir.resolve() + + if not args.skip_compile: + ensure_local_artifacts(args, args.model) + else: + ensure_existing_artifacts(args.model.parent) + + original_image, tensor = preprocess_classification_image(args.image) + labels = load_labels(args.labels) + reference_output, simulator_output = run_reference_and_simulator(args, args.model, tensor) + reference_results = decode_classification_output(reference_output, labels, args.top_k) + simulator_results = decode_classification_output(simulator_output, labels, args.top_k) + + print_topk("Reference top-k:", reference_results) + print_topk("Simulator top-k:", simulator_results) + + reference_scores = np.asarray(reference_output, dtype=np.float64).reshape(-1) + simulator_scores = np.asarray(simulator_output, dtype=np.float64).reshape(-1) + max_abs_diff = float(np.max(np.abs(reference_scores - simulator_scores))) + print(f"Max absolute score diff: {max_abs_diff:.6e}") + + args.output.parent.mkdir(parents=True, exist_ok=True) + draw_classification_panel(original_image, simulator_results, args.output) + print(f"Annotated image saved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/validation/tools/compare_raptor_pimcomp.py b/validation/tools/compare_raptor_pimcomp.py new file mode 100644 index 0000000..4127410 --- /dev/null +++ b/validation/tools/compare_raptor_pimcomp.py @@ -0,0 +1,1497 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import gzip +import importlib.util +import json +import mmap +import os +import re +import shlex +import shutil +import subprocess +import sys +import time +import types +from collections import Counter +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import numpy as np +import onnx + + +REPO = Path(__file__).resolve().parents[2] +VALIDATION_DIR = REPO / "validation" +sys.path.insert(0, str(VALIDATION_DIR)) + +from gen_network_runner import 
gen_network_runner # noqa: E402 +from onnx_utils import _ONNX_TO_NP, gen_random_inputs, onnx_io, save_inputs_to_files, write_inputs_to_memory_bin # noqa: E402 +from validate_one import build_dump_ranges, parse_pim_simulator_outputs # noqa: E402 +from raptor import compile_with_raptor # noqa: E402 + + +@dataclass +class StepRecord: + name: str + duration_sec: float + command: str + status: str = "passed" + returncode: int | None = None + error: str | None = None + output_tail: str | None = None + + +@dataclass +class CompareResult: + passed: bool + max_diffs: dict[str, float] + status: str = "done" + error: str | None = None + + +def load_pimcomp_exporter(): + path = REPO / "third_party/PIMCOMP-NN/verification/export_to_pim_simulator.py" + spec = importlib.util.spec_from_file_location("pimcomp_exporter", path) + module = importlib.util.module_from_spec(spec) + assert spec is not None and spec.loader is not None + sys.modules.setdefault("cv2", types.ModuleType("cv2")) + spec.loader.exec_module(module) + return module + + +def load_mesh_builder(): + path = REPO / "validation/pimsim-configs/generate_mesh_config.py" + spec = importlib.util.spec_from_file_location("mesh_builder", path) + module = importlib.util.module_from_spec(spec) + assert spec is not None and spec.loader is not None + spec.loader.exec_module(module) + return module + + +def shell_join(cmd: list[str]) -> str: + return shlex.join(str(arg) for arg in cmd) + + +def print_step(name: str, cmd: list[str] | None = None, cwd: Path | None = None): + print(f"\n[{name}]") + if cmd is not None: + print(f" cwd: {cwd or REPO}") + print(f" $ {shell_join(cmd)}") + + +def output_tail(output: str | bytes | None, limit: int = 4000) -> str: + if output is None: + return "" + if isinstance(output, bytes): + output = output.decode(errors="replace") + return output[-limit:] + + +def exception_message(exc: BaseException) -> str: + if isinstance(exc, subprocess.CalledProcessError): + command = shell_join([str(arg) for arg in 
exc.cmd]) if isinstance(exc.cmd, list) else str(exc.cmd) + tail = output_tail(exc.output) + message = f"command failed with exit code {exc.returncode}: {command}" + if tail: + message += f"\n--- output tail ---\n{tail}" + return message + if isinstance(exc, subprocess.TimeoutExpired): + command = shell_join([str(arg) for arg in exc.cmd]) if isinstance(exc.cmd, list) else str(exc.cmd) + tail = output_tail(exc.output) + message = f"command timed out after {exc.timeout} seconds: {command}" + if tail: + message += f"\n--- output tail ---\n{tail}" + return message + return f"{type(exc).__name__}: {exc}" + + +def print_failure(name: str, exc: BaseException | str) -> None: + message = exc if isinstance(exc, str) else exception_message(exc) + print(f"\n[{name} FAILED]") + for line in message.splitlines()[:20]: + print(f" {line}") + + +def run_logged( + name: str, + cmd: list[str], + *, + cwd: Path, + timeout_sec: float, + steps: list[StepRecord], +) -> str: + print_step(name, cmd, cwd) + start = time.perf_counter() + command = shell_join(cmd) + try: + proc = subprocess.run( + cmd, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + timeout=timeout_sec, + ) + except subprocess.TimeoutExpired as exc: + duration = time.perf_counter() - start + tail = output_tail(exc.output) + steps.append( + StepRecord( + name=name, + duration_sec=duration, + command=command, + status="timeout", + error=f"Timed out after {timeout_sec} seconds", + output_tail=tail or None, + ) + ) + raise + + duration = time.perf_counter() - start + if proc.returncode != 0: + tail = output_tail(proc.stdout) + steps.append( + StepRecord( + name=name, + duration_sec=duration, + command=command, + status="failed", + returncode=proc.returncode, + error=f"Exited with status {proc.returncode}", + output_tail=tail or None, + ) + ) + raise subprocess.CalledProcessError(proc.returncode, cmd, output=tail) + + steps.append(StepRecord(name=name, duration_sec=duration, command=command)) + return 
proc.stdout + + +def remove_tree(path: Path) -> None: + if not path.exists() and not path.is_symlink(): + return + if path.is_symlink() or path.is_file(): + path.unlink() + return + while True: + children = list(path.iterdir()) + if not children: + break + for child in children: + remove_tree(child) + path.rmdir() + + +def load_model_inputs(model_path: Path, seed: int): + model = onnx.load(model_path) + initializer_names = {init.name for init in model.graph.initializer} + initializer_values = { + init.name: onnx.numpy_helper.to_array(init) for init in model.graph.initializer + } + inputs_desc, outputs_desc = onnx_io(model_path) + runtime_desc = [desc for desc in inputs_desc if desc[1] not in initializer_names] + runtime_arrays, _ = gen_random_inputs(runtime_desc, seed=seed) + + runtime_by_name = { + desc[1]: arr for desc, arr in zip(runtime_desc, runtime_arrays) + } + arrays_in_order = [] + for _, name, elem_type, _ in inputs_desc: + if name in initializer_values: + arrays_in_order.append(initializer_values[name].astype(_ONNX_TO_NP[elem_type], copy=False)) + else: + arrays_in_order.append(runtime_by_name[name]) + runtime_only = [arr for desc, arr in zip(inputs_desc, arrays_in_order) if desc[1] not in initializer_names] + return inputs_desc, outputs_desc, arrays_in_order, runtime_only + + +def compare_simulator_outputs( + output_bin: Path, + outputs_desc: list[tuple[int, str, int, list[int]]], + reference_dir: Path, + *, + threshold: float, + rtol: float, +) -> CompareResult: + sim_arrays = parse_pim_simulator_outputs(output_bin, outputs_desc) + max_diffs: dict[str, float] = {} + passed = True + for sim_array, (idx, name, _, shape) in zip(sim_arrays, outputs_desc): + csv_name = reference_dir / f"output{idx}_{sanitize_output_name(name)}.csv" + ref = np.loadtxt(csv_name, delimiter=",", dtype=np.float32).reshape(shape) + diff = np.abs(sim_array.astype(np.float64) - ref.astype(np.float64)) + allowed = threshold + rtol * np.abs(ref.astype(np.float64)) + max_diffs[name] = 
float(np.max(diff)) + if not np.all(diff <= allowed): + passed = False + return CompareResult(passed=passed, max_diffs=max_diffs) + + +def sanitize_output_name(name: str) -> str: + return "".join(ch if ch.isalnum() or ch in "_.-" else "_" for ch in name[:255]) + + +def load_effective_hardware(args: argparse.Namespace) -> dict[str, int]: + config_path = args.pimcomp_dir / "config.json" + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + rows, cols = config["chip_config"]["network_config"]["layout"] + xbar_h, xbar_w = config["chip_config"]["core_config"]["matrix_config"]["xbar_size"] + hardware = { + "mesh_rows": args.mesh_rows or rows, + "mesh_cols": args.mesh_cols or cols, + "crossbar_count": args.crossbar_count or config["chip_config"]["core_config"]["matrix_config"]["xbar_array_count"], + "crossbar_size": args.crossbar_size or xbar_h, + } + if xbar_h != xbar_w: + raise ValueError(f"Only square crossbars are supported, got {xbar_h}x{xbar_w}") + hardware["core_count"] = args.core_count or hardware["mesh_rows"] * hardware["mesh_cols"] + return hardware + + +def write_pimsim_config(args: argparse.Namespace, out_dir: Path, hardware: dict[str, int]) -> Path: + mesh_builder = load_mesh_builder() + example_config = REPO / "backend-simulators/pim/pimsim-nn/example/config/latency_config.json" + with open(example_config, "r", encoding="utf-8") as f: + config = json.load(f) + config["chip_config"]["core_config"]["matrix_config"]["xbar_array_count"] = hardware["crossbar_count"] + config["chip_config"]["core_config"]["matrix_config"]["xbar_size"] = [ + hardware["crossbar_size"], + hardware["crossbar_size"], + ] + config["chip_config"]["network_config"]["layout"] = [ + hardware["mesh_rows"], + hardware["mesh_cols"], + ] + config["chip_config"]["network_config"]["net_config_file_path"] = f"network_mesh_{hardware['core_count']}.json" + config["chip_config"]["core_cnt"] = hardware["core_count"] + config["sim_config"]["sim_mode"] = 1 if args.pimsim_mode 
== "latency" else 0 + config["sim_config"]["sim_time"] = args.pimsim_time_ms + out_dir.mkdir(parents=True, exist_ok=True) + config_path = out_dir / f"{args.pimsim_mode}_config.json" + network_path = out_dir / f"network_mesh_{hardware['core_count']}.json" + with open(config_path, "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + f.write("\n") + with open(network_path, "w", encoding="utf-8") as f: + json.dump( + mesh_builder.build_network( + hardware["core_count"], + (hardware["mesh_rows"], hardware["mesh_cols"]), + ), + f, + separators=(",", ":"), + ) + f.write("\n") + return config_path + + +def compile_reference( + args: argparse.Namespace, + model_path: Path, + work_dir: Path, + steps: list[StepRecord], +) -> tuple[Path, Path, Path]: + raptor_dir = work_dir / "reference" + runner_dir = work_dir / "runner" + build_dir = runner_dir / "build" + raptor_dir.mkdir(parents=True, exist_ok=True) + build_dir.mkdir(parents=True, exist_ok=True) + stem = model_path.stem + onnx_ir_base = raptor_dir / stem + runner_base = runner_dir / stem + + run_logged( + "Reference Emit ONNX IR", + [str(args.raptor_path), str(model_path), "-o", str(onnx_ir_base), "--EmitONNXIR"], + cwd=REPO, + timeout_sec=args.timeout_seconds, + steps=steps, + ) + run_logged( + "Reference Native Compile", + [str(args.raptor_path), "-O3", str(model_path), "-o", str(runner_base)], + cwd=REPO, + timeout_sec=args.timeout_seconds, + steps=steps, + ) + network_so = runner_base.with_suffix(".so") + network_mlir = onnx_ir_base.with_suffix(".onnx.mlir") + + print_step("Generate Runner Source") + gen_network_runner(model_path, network_so, args.onnx_include_dir, out=runner_dir / "runner.c", verbose=False) + + run_logged( + "Configure Runner", + ["cmake", str(runner_dir), "-DCMAKE_BUILD_TYPE=Release", "-DCMAKE_C_FLAGS_RELEASE=-O3"], + cwd=build_dir, + timeout_sec=args.timeout_seconds, + steps=steps, + ) + run_logged( + "Build Runner", + ["cmake", "--build", ".", "-j"], + cwd=build_dir, + 
timeout_sec=args.timeout_seconds, + steps=steps, + ) + return network_mlir, network_so, build_dir / "runner" + + +def generate_reference_outputs( + runner_path: Path, + runner_build_dir: Path, + model_path: Path, + arrays_in_order: list[np.ndarray], + steps: list[StepRecord], + args: argparse.Namespace, + out_dir: Path, +) -> Path: + inputs_dir = out_dir / "inputs" + reference_dir = out_dir / "reference_outputs" + inputs_dir.mkdir(parents=True, exist_ok=True) + reference_dir.mkdir(parents=True, exist_ok=True) + flags, _ = save_inputs_to_files(model_path, arrays_in_order, inputs_dir) + run_logged( + "Run Reference", + [str(runner_path), *flags, "--save-csv-dir", str(reference_dir)], + cwd=runner_build_dir, + timeout_sec=args.timeout_seconds, + steps=steps, + ) + return reference_dir + + +def compile_raptor_target( + model_mlir: Path, + out_dir: Path, + hardware: dict[str, int], + args: argparse.Namespace, + steps: list[StepRecord], +) -> tuple[Path, dict[str, float]]: + out_dir.mkdir(parents=True, exist_ok=True) + cmd = [ + str(args.raptor_path), + str(model_mlir), + "-o", + str(out_dir / "model"), + "--maccel=PIM", + "--EmitPimCodegen", + f"--crossbar-size={hardware['crossbar_size']}", + f"--crossbar-count={hardware['crossbar_count']}", + f"--core-count={hardware['core_count']}", + "--pim-emit-json", + *args.raptor_extra_arg, + ] + print_step("Compile Raptor PIM", cmd, REPO) + start = time.perf_counter() + command = shell_join(cmd) + raptor_extra_args = ["--pim-emit-json", *args.raptor_extra_arg] + try: + timings = compile_with_raptor( + model_mlir, + args.raptor_path, + out_dir / "model", + hardware["crossbar_size"], + hardware["crossbar_count"], + core_count=hardware["core_count"], + raptor_extra_args=raptor_extra_args, + cwd=out_dir, + verbose=args.verbose_raptor_compile, + timeout_sec=args.timeout_seconds, + ) + except Exception as exc: + steps.append( + StepRecord( + name="Compile Raptor PIM", + duration_sec=time.perf_counter() - start, + command=command, + 
status="failed", + error=exception_message(exc), + ) + ) + raise + + steps.append( + StepRecord( + name="Compile Raptor PIM", + duration_sec=time.perf_counter() - start, + command=command, + ) + ) + return out_dir / "pim", timings + + +def run_rust_validation( + label: str, + pim_dir: Path, + config_path: Path, + outputs_desc: list[tuple[int, str, int, list[int]]], + reference_dir: Path, + steps: list[StepRecord], + args: argparse.Namespace, +) -> CompareResult: + output_bin = pim_dir.parent / "semantic_validation" / "out.bin" + dump_ranges = build_dump_ranges(config_path, outputs_desc) + cmd = [ + "cargo", + "run", + "--no-default-features", + "--release", + "--package", + "pim-simulator", + "--bin", + "pim-simulator", + "--", + "-f", + str(pim_dir), + "-o", + str(output_bin), + "-d", + dump_ranges, + ] + simulation_dir = pim_dir.parent / "semantic_validation" + simulation_dir.mkdir(parents=True, exist_ok=True) + run_logged( + label, + cmd, + cwd=args.pim_simulator_dir, + timeout_sec=args.timeout_seconds, + steps=steps, + ) + return compare_simulator_outputs( + output_bin, + outputs_desc, + reference_dir, + threshold=args.threshold, + rtol=args.rtol, + ) + + +def copy_pimcomp_outputs(args: argparse.Namespace, out_dir: Path): + out_dir.mkdir(parents=True, exist_ok=True) + for name in ("SimulationInfo.gz", "VerificationInfo.json", "MappingResult.txt"): + shutil.copy2(args.pimcomp_dir / "output" / name, out_dir / name) + + +def compile_pimcomp( + args: argparse.Namespace, + model_path: Path, + out_dir: Path, + steps: list[StepRecord], +) -> tuple[Path, Path]: + out_dir.mkdir(parents=True, exist_ok=True) + model_name = f"compare_{model_path.stem}" + frontend_json = args.pimcomp_dir / "models/JSON" / f"{model_name}.json" + frontend_cmd = [ + "python3", + "frontend.py", + "--model_path", + str(model_path), + "--save_path", + str(frontend_json), + ] + run_logged( + "PIMCOMP Frontend", + frontend_cmd, + cwd=args.pimcomp_dir / "frontend", + 
timeout_sec=args.timeout_seconds, + steps=steps, + ) + backend_cmd = [ + str(args.pimcomp_dir / "build" / "PIMCOMP-NN"), + f"-m={model_name}", + "-p=batch", + "-v=YES", + "-s=YES", + ] + run_logged( + "PIMCOMP Backend", + backend_cmd, + cwd=args.pimcomp_dir / "build", + timeout_sec=args.timeout_seconds, + steps=steps, + ) + copy_pimcomp_outputs(args, out_dir) + return out_dir / "VerificationInfo.json", out_dir / "SimulationInfo.gz" + + +def export_pimcomp_for_pimsim_nn(simulation_info: Path, output_dir: Path) -> Path: + if output_dir.exists(): + remove_tree(output_dir) + with gzip.open(simulation_info, "rt", encoding="utf-8") as f: + sim_info = json.load(f) + + output_dir.mkdir(parents=True, exist_ok=True) + sim_config = sim_info["config"] + present_core_indices = sorted( + int(key[4:]) for key, value in sim_info.items() if key.startswith("core") and isinstance(value, list) and value + ) + if not present_core_indices: + raise ValueError("PIMCOMP SimulationInfo.gz does not contain any non-empty core instruction streams") + expected_core_indices = list(range(present_core_indices[-1] + 1)) + if present_core_indices != expected_core_indices: + raise ValueError(f"PIMCOMP core numbering is not contiguous: {present_core_indices}") + + config = { + "core_cnt": len(present_core_indices), + "xbar_size": sim_config["xbar_size"], + "xbar_array_count": sim_config["xbar_array_count"], + "cell_precision": sim_config["cell_precision"], + "adc_count": sim_config["adc_count"], + "array_group_map": {}, + } + for core_idx in present_core_indices: + core_name = f"core{core_idx}" + config["array_group_map"][core_name] = sim_config["array_group_map"].get(core_name, []) + + with open(output_dir / "config.json", "w", encoding="utf-8") as f: + json.dump(config, f, separators=(",", ":")) + f.write("\n") + + for core_idx in present_core_indices: + core_key = f"core{core_idx}" + instructions = sim_info[core_key] + with open(output_dir / f"core_{core_idx}.json", "w", encoding="utf-8") as f: + 
json.dump(instructions, f, separators=(",", ":")) + f.write("\n") + return output_dir + + +def flatten_pimcomp_input(array: np.ndarray) -> np.ndarray: + tensor = array.astype(np.float32, copy=False) + if tensor.ndim == 4: + tensor = tensor.transpose((0, 2, 3, 1)) + return tensor.reshape(-1) + + +def export_pimcomp_for_rust( + model_path: Path, + verification_info: Path, + simulation_info: Path, + runtime_inputs: list[np.ndarray], + output_dir: Path, +) -> Path: + if len(runtime_inputs) != 1: + raise ValueError("PIMCOMP export currently requires exactly one runtime input tensor") + if output_dir.exists(): + remove_tree(output_dir) + exporter = load_pimcomp_exporter() + with open(verification_info, "r", encoding="utf-8") as f: + final_info = json.load(f) + with gzip.open(simulation_info, "rt", encoding="utf-8") as f: + sim_info = json.load(f) + + onnx_model, weights, gemm_weights, output_to_weight, output_to_bias = exporter.load_model_info( + model_path, final_info + ) + input_tensor = flatten_pimcomp_input(runtime_inputs[0]) + node_list = final_info["node_list"] + max_output = exporter.max_output_element_num(node_list) + local_group_map = exporter.map_local_groups(final_info, sim_info) + + output_dir.mkdir(parents=True, exist_ok=True) + weights_dir = output_dir / "weights" + weights_dir.mkdir(parents=True, exist_ok=True) + + input_addr = 0 + cursor = exporter.byte_offset(len(input_tensor)) + bias_addrs: dict[str, int] = {} + for node_name, bias_name in output_to_bias.items(): + bias = weights[bias_name].astype(np.float32).flatten() + bias_addrs[node_name] = cursor + cursor += exporter.byte_offset(len(bias)) + + lldi_addrs: dict[tuple[bytes, int], int] = {} + for core_idx in range(sim_info["config"]["core_cnt"]): + for inst in sim_info.get(f"core{core_idx}", []) or []: + if inst["op"] != "lldi": + continue + key = (exporter.float32_bytes(inst["imm"]), inst["len"]) + if key not in lldi_addrs: + lldi_addrs[key] = cursor + cursor += exporter.byte_offset(inst["len"]) + + 
def export_pimcomp_for_rust(
    model_path: Path,
    verification_info: Path,
    simulation_info: Path,
    runtime_inputs: list[np.ndarray],
    output_dir: Path,
) -> Path:
    """Export a PIMCOMP compilation into a per-core directory layout.

    Reads the verification JSON and the gzip-compressed simulation JSON, then
    writes the following under *output_dir* (which is removed first if it
    already exists):

    * ``weights/crossbar_<n>.bin`` - one zero-padded float32 tile per crossbar;
    * ``core_<i>/crossbar_<k>.bin`` - symlinks to those tiles;
    * ``core_<i>.json`` - the translated instruction list of core ``i``;
    * ``config.json`` - hardware parameters plus input/output addresses;
    * ``memory.bin`` - the initial memory image.

    Args:
        model_path: Model file handed to ``exporter.load_model_info``.
        verification_info: Path of the verification JSON file.
        simulation_info: Path of the gzip-compressed simulation JSON file.
        runtime_inputs: Input tensors; exactly one is supported.
        output_dir: Destination directory.

    Returns:
        *output_dir*.

    Raises:
        ValueError: If *runtime_inputs* does not hold exactly one tensor.
        RuntimeError: If the simulation and verification instruction streams
            disagree, or an unsupported LD stage or op is encountered.
    """
    if len(runtime_inputs) != 1:
        raise ValueError("PIMCOMP export currently requires exactly one runtime input tensor")
    if output_dir.exists():
        remove_tree(output_dir)
    exporter = load_pimcomp_exporter()
    with open(verification_info, "r", encoding="utf-8") as f:
        final_info = json.load(f)
    with gzip.open(simulation_info, "rt", encoding="utf-8") as f:
        sim_info = json.load(f)

    onnx_model, weights, gemm_weights, output_to_weight, output_to_bias = exporter.load_model_info(
        model_path, final_info
    )
    input_tensor = flatten_pimcomp_input(runtime_inputs[0])
    node_list = final_info["node_list"]
    max_output = exporter.max_output_element_num(node_list)
    local_group_map = exporter.map_local_groups(final_info, sim_info)

    output_dir.mkdir(parents=True, exist_ok=True)
    weights_dir = output_dir / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)

    # Memory image layout, in order: the input tensor at address 0, then one
    # region per bias, then one region per distinct lldi constant.
    # NOTE(review): exporter.byte_offset presumably converts an element count
    # into a byte count - confirm in the exporter module.
    input_addr = 0
    cursor = exporter.byte_offset(len(input_tensor))
    bias_addrs: dict[str, int] = {}
    for node_name, bias_name in output_to_bias.items():
        bias = weights[bias_name].astype(np.float32).flatten()
        bias_addrs[node_name] = cursor
        cursor += exporter.byte_offset(len(bias))

    # Identical (value, length) lldi constants share a single region.
    lldi_addrs: dict[tuple[bytes, int], int] = {}
    for core_idx in range(sim_info["config"]["core_cnt"]):
        # "or []" tolerates a core entry that is present but null.
        for inst in sim_info.get(f"core{core_idx}", []) or []:
            if inst["op"] != "lldi":
                continue
            key = (exporter.float32_bytes(inst["imm"]), inst["len"])
            if key not in lldi_addrs:
                lldi_addrs[key] = cursor
                cursor += exporter.byte_offset(inst["len"])

    # Round the start of the output area up to a multiple of 256; the area
    # reserves max_output elements for every entry of node_list.
    output_base = (cursor + 255) & ~255
    memory_size = output_base + exporter.byte_offset(max_output * len(node_list))
    memory = bytearray(memory_size)
    memory[input_addr : input_addr + input_tensor.nbytes] = input_tensor.tobytes()
    for node_name, bias_name in output_to_bias.items():
        bias = weights[bias_name].astype(np.float32).flatten()
        start = bias_addrs[node_name]
        memory[start : start + bias.nbytes] = bias.tobytes()
    for (value_bytes, element_num), start in lldi_addrs.items():
        # Rebuild the float32 constant from its byte key and replicate it.
        value = np.frombuffer(value_bytes, dtype=np.float32)[0]
        blob = np.full(element_num, value, dtype=np.float32)
        memory[start : start + blob.nbytes] = blob.tobytes()

    config = {
        "core_cnt": sim_info["config"]["core_cnt"],
        "xbar_size": sim_info["config"]["xbar_size"],
        "xbar_array_count": sim_info["config"]["xbar_array_count"],
        "cell_precision": sim_info["config"]["cell_precision"],
        "adc_count": sim_info["config"]["adc_count"],
        "array_group_map": {},
        "inputs_addresses": [input_addr],
        "outputs_addresses": [],
    }
    # Each graph output is looked up by name in node_list; its slot in the
    # output area is selected by the node's "new_node_index".
    output_name_to_node = {node["name"]: node for node in node_list}
    for graph_output in onnx_model.graph.output:
        node = output_name_to_node[graph_output.name]
        config["outputs_addresses"].append(output_base + exporter.byte_offset(node["new_node_index"] * max_output))

    ag_info = final_info["AG_info"]
    weight_counter = 0
    xbar_size = int(sim_info["config"]["xbar_size"][0])
    for core_idx in range(config["core_cnt"]):
        core_name = f"core{core_idx}"
        core_dir = output_dir / f"core_{core_idx}"
        core_dir.mkdir(parents=True, exist_ok=True)
        local_to_global = local_group_map.get(core_idx, {})
        ag_counts = sim_info["config"]["array_group_map"].get(core_name, [])
        # group_prefix[g] is the running total of the counts before group g,
        # i.e. the index of the first physical crossbar of local group g.
        group_prefix = []
        total_crossbars = 0
        for count in ag_counts:
            group_prefix.append(total_crossbars)
            total_crossbars += count
        config["array_group_map"][core_name] = list(range(total_crossbars))

        for local_group, global_ag in sorted(local_to_global.items()):
            info = ag_info[global_ag]
            weight_name = output_to_weight[info["node_name"]]
            matrix = gemm_weights[weight_name]
            # "height_end" / "width_end" are inclusive, hence the "+ 1".
            row_slice = slice(info["height_start"], info["height_end"] + 1)
            first_physical = group_prefix[local_group]
            for crossbar_idx, crossbar in enumerate(info["crossbar"]):
                col_slice = slice(crossbar["width_start"], crossbar["width_end"] + 1)
                # The tile always has xbar_size rows; rows beyond the weight
                # slice stay zero.
                tile = np.zeros((xbar_size, col_slice.stop - col_slice.start), dtype=np.float32)
                tile_rows = matrix[row_slice, col_slice].astype(np.float32)
                tile[: tile_rows.shape[0], :] = tile_rows
                weight_path = weights_dir / f"crossbar_{weight_counter}.bin"
                weight_path.write_bytes(tile.tobytes(order="C"))
                os.symlink(weight_path.resolve(), core_dir / f"crossbar_{first_physical + crossbar_idx}.bin")
                weight_counter += 1

        instructions = []
        # Maps a register number to the index (in `instructions`) of the most
        # recent translated "sldi" that targeted it, so that a later ld/st can
        # overwrite that sldi's immediate with the final address.
        last_sldi_by_rd: dict[int, int] = {}
        ver_ops = exporter.filtered_verification_ops(final_info, core_idx)
        ver_index = 0
        for sim_inst in sim_info.get(core_name, []) or []:
            op = sim_inst["op"]
            # setbw and sldi do not consume an entry of ver_ops.
            if op == "setbw":
                instructions.append(sim_inst)
                continue
            if op == "sldi":
                translated = {"op": "sldi", "rd": sim_inst["rd"], "imm": exporter.byte_offset(sim_inst["imm"])}
                instructions.append(translated)
                last_sldi_by_rd[sim_inst["rd"]] = len(instructions) - 1
                continue
            # Every other simulation op is paired, in order, with the next
            # verification op and must carry the same operation name.
            if ver_index >= len(ver_ops):
                raise RuntimeError(f"core{core_idx}: simulation op {op} has no matching verification op")
            ver_inst = ver_ops[ver_index]
            ver_index += 1
            ver_op = ver_inst["operation"].lower()
            if ver_op != op:
                raise RuntimeError(
                    f"core{core_idx}: simulation/verification op mismatch {op} vs {ver_op} at {ver_index - 1}"
                )
            if op == "ld":
                if ver_inst["stage"] == "INPUT":
                    src = input_addr + exporter.byte_offset(ver_inst["source_offset"])
                elif ver_inst["stage"] == "BIAS":
                    src = bias_addrs[node_list[ver_inst["node_index"]]["name"]] + exporter.byte_offset(ver_inst["source_offset"])
                else:
                    raise RuntimeError(f"Unsupported LD stage {ver_inst['stage']}")
                # Patch the sldi that last loaded the source register rs1.
                instructions[last_sldi_by_rd[sim_inst["rs1"]]]["imm"] = src
                translated = dict(sim_inst)
                translated["size"] = exporter.byte_offset(sim_inst["size"])
                instructions.append(translated)
            elif op == "st":
                dst = output_base + exporter.byte_offset(
                    ver_inst["node_index"] * max_output + ver_inst["destination_offset"]
                )
                # Patch the sldi that last loaded the destination register rd.
                instructions[last_sldi_by_rd[sim_inst["rd"]]]["imm"] = dst
                translated = dict(sim_inst)
                translated["size"] = exporter.byte_offset(sim_inst["size"])
                instructions.append(translated)
            elif op == "lldi":
                # lldi is replaced by "sldi <temp>, <constant address>" followed
                # by an "ld" from that address; the temporary register is 1
                # when rd is 0 and 0 otherwise, so it never equals rd.
                key = (exporter.float32_bytes(sim_inst["imm"]), sim_inst["len"])
                src = lldi_addrs[key]
                temp_rd = 1 if sim_inst["rd"] == 0 else 0
                instructions.append({"op": "sldi", "rd": temp_rd, "imm": src})
                instructions.append(
                    {
                        "op": "ld",
                        "rd": sim_inst["rd"],
                        "rs1": temp_rd,
                        "size": exporter.byte_offset(sim_inst["len"]),
                        "offset": sim_inst["offset"],
                    }
                )
            elif op in ("lmv", "vvadd", "vvmul", "vvmax", "vrelu"):
                translated = dict(sim_inst)
                translated["len"] = exporter.byte_offset(sim_inst["len"])
                instructions.append(translated)
            elif op in ("send", "recv"):
                translated = dict(sim_inst)
                translated["size"] = exporter.byte_offset(sim_inst["size"])
                instructions.append(translated)
            elif op == "mvmul":
                # One mvmul per physical crossbar of the group: each copy is
                # preceded by sldi instructions that set rd to the running
                # output offset and reset rs1 to the original source.
                local_group = sim_inst["group"]
                global_ag = local_to_global[local_group]
                first_physical = group_prefix[local_group]
                widths = [
                    crossbar["width_end"] - crossbar["width_start"] + 1
                    for crossbar in ag_info[global_ag]["crossbar"]
                ]
                dst = instructions[last_sldi_by_rd[sim_inst["rd"]]]["imm"]
                src = instructions[last_sldi_by_rd[sim_inst["rs1"]]]["imm"]
                out_offset = 0
                for idx, width in enumerate(widths):
                    # NOTE(review): these sldi entries are not recorded in
                    # last_sldi_by_rd - confirm that this is intended.
                    instructions.append({"op": "sldi", "rd": sim_inst["rd"], "imm": dst + exporter.byte_offset(out_offset)})
                    instructions.append({"op": "sldi", "rd": sim_inst["rs1"], "imm": src})
                    translated = dict(sim_inst)
                    translated["group"] = first_physical + idx
                    instructions.append(translated)
                    out_offset += width
            else:
                raise RuntimeError(f"Unsupported PIMCOMP op {op}")

        with open(output_dir / f"core_{core_idx}.json", "w", encoding="utf-8") as f:
            json.dump(instructions, f, separators=(",", ":"))
            f.write("\n")

    with open(output_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, separators=(",", ":"))
        f.write("\n")
    (output_dir / "memory.bin").write_bytes(memory)
    return output_dir
def parse_pimsim_nn_report(output: str) -> dict[str, float | int | str]:
    """Extract the metrics printed by pimsim-nn from *output*.

    Args:
        output: Captured text of a pimsim-nn run.

    Returns:
        A dict that always contains ``"raw_output"`` (the unmodified text)
        plus one entry per metric line found: ``output_count`` as ``int``,
        every other metric as ``float``. Metrics that do not appear in the
        text are simply absent.
    """
    # A strict decimal/scientific literal, so float() below can never fail on
    # a captured token (a bare character class such as [0-9.eE+-]+ also
    # accepts strings like "-" or "1.2.3").
    number = r"([-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?)"
    patterns = {
        "output_count": r"output count:\s+([0-9]+)\s+samples",
        "throughput": rf"throughput:\s+{number}\s+samples/s",
        "average_latency_ms": rf"average latency:\s+{number}\s+ms",
        # Without the look-behind, "latency:" also matches inside
        # "average latency:" and reports the average as the latency.
        "latency_ms": rf"(?<!average )latency:\s+{number}\s+ms",
        "average_power_mw": rf"average power:\s+{number}\s+mW",
        "average_energy_pj": rf"average energy:\s+{number}\s+pJ/it",
    }
    result: dict[str, float | int | str] = {"raw_output": output}
    for key, pattern in patterns.items():
        match = re.search(pattern, output)
        if match is None:
            continue
        value = match.group(1)
        result[key] = int(value) if key == "output_count" else float(value)
    return result


def run_pimsim_nn(
    label: str,
    inst_path: Path,
    config_path: Path,
    single_file: bool,
    steps: list[StepRecord],
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Run ``ChipTest`` from the pimsim-nn build directory and parse its report.

    The command is executed through ``run_logged`` with the build directory as
    working directory and ten times ``args.timeout_seconds`` as timeout.

    Args:
        label: Step label passed to ``run_logged``.
        inst_path: Instruction path given to ``ChipTest`` as first argument.
        config_path: Configuration path given as second argument.
        single_file: Forwarded to ``ChipTest`` as the literal ``true``/``false``.
        steps: Step list passed to ``run_logged``.
        args: Parsed CLI arguments (``pimsim_nn_build_dir``, ``timeout_seconds``).

    Returns:
        The dict produced by ``parse_pimsim_nn_report``.
    """
    build_dir = args.pimsim_nn_build_dir
    cmd = [
        str(build_dir / "ChipTest"),
        str(inst_path),
        str(config_path),
        "true" if single_file else "false",
    ]
    output = run_logged(
        label,
        cmd,
        cwd=build_dir,
        timeout_sec=args.timeout_seconds * 10.0,
        steps=steps,
    )
    return parse_pimsim_nn_report(output)


def parse_raptor_instructions(pim_dir: Path) -> dict[str, Any]:
    """Count instruction opcodes in the ``core_<n>.json`` files of *pim_dir*.

    The files are scanned with a regular expression over a memory map rather
    than parsed as JSON; every ``"op":"<name>"`` pair counts as one
    instruction (whitespace around the colon is tolerated).

    Args:
        pim_dir: Directory containing ``core_<n>.json`` files.

    Returns:
        A summary dict with ``active_cores``, ``total_instructions``,
        ``op_counts`` and three top-10 rankings of per-core entries
        (by total, by ``send`` and by ``recv``).

    Raises:
        ValueError: If a matching file name has a non-numeric core index.
    """
    op_re = re.compile(rb'"op"\s*:\s*"([^"]+)"')
    counts = Counter()
    per_core = []
    for path in sorted(pim_dir.glob("core_*.json"), key=lambda p: int(p.stem.split("_")[1])):
        core_counts = Counter()
        # mmap raises ValueError for a zero-length file; such a file simply
        # contributes no instructions.
        if path.stat().st_size > 0:
            # The context managers release the map and the file even when
            # counting raises.
            with path.open("rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                core_counts.update(m.group(1).decode() for m in op_re.finditer(mm))
        total = sum(core_counts.values())
        counts.update(core_counts)
        per_core.append(
            {
                "core": path.stem,
                "total": total,
                "send": core_counts.get("send", 0),
                "recv": core_counts.get("recv", 0),
                "mvmul": core_counts.get("mvmul", 0),
            }
        )
    return {
        "active_cores": sum(1 for entry in per_core if entry["total"]),
        "total_instructions": int(sum(counts.values())),
        "op_counts": dict(counts),
        "top_cores_by_total": sorted(per_core, key=lambda entry: entry["total"], reverse=True)[:10],
        "top_cores_by_send": sorted(per_core, key=lambda entry: entry["send"], reverse=True)[:10],
        "top_cores_by_recv": sorted(per_core, key=lambda entry: entry["recv"], reverse=True)[:10],
    }
def parse_pimcomp_instructions(simulation_info: Path) -> dict[str, Any]:
    """Count instruction opcodes per core in a gzip-compressed simulation JSON.

    Only top-level keys of the form ``core<digits>`` are treated as cores;
    they are visited in numeric order. An instruction's opcode is read from
    ``"operation"``, falling back to ``"op"`` and finally to ``"unknown"``,
    and is lower-cased before counting.

    Args:
        simulation_info: Path of the gzip-compressed JSON file.

    Returns:
        A summary dict with ``active_cores``, ``total_instructions``,
        ``op_counts`` and three top-10 rankings of per-core entries
        (by total, by ``send`` and by ``recv``).
    """
    with gzip.open(simulation_info, "rt", encoding="utf-8") as f:
        data = json.load(f)
    # Require a purely numeric suffix so that any other key that merely
    # starts with "core" cannot crash int() below.
    core_names = sorted(
        (name for name in data if name.startswith("core") and name[4:].isdecimal()),
        key=lambda name: int(name[4:]),
    )
    per_core = []
    counts = Counter()
    for key in core_names:
        # A core entry may be null; export_pimcomp_for_rust tolerates that
        # the same way.
        insts = data[key] or []
        core_counts = Counter((inst.get("operation") or inst.get("op") or "unknown").lower() for inst in insts)
        counts.update(core_counts)
        per_core.append(
            {
                "core": key,
                "total": int(sum(core_counts.values())),
                "send": core_counts.get("send", 0),
                "recv": core_counts.get("recv", 0),
                "mvmul": core_counts.get("mvmul", 0),
            }
        )
    return {
        "active_cores": sum(1 for entry in per_core if entry["total"]),
        "total_instructions": int(sum(counts.values())),
        "op_counts": dict(counts),
        "top_cores_by_total": sorted(per_core, key=lambda entry: entry["total"], reverse=True)[:10],
        "top_cores_by_send": sorted(per_core, key=lambda entry: entry["send"], reverse=True)[:10],
        "top_cores_by_recv": sorted(per_core, key=lambda entry: entry["recv"], reverse=True)[:10],
    }


def format_op_table(counts: dict[str, int], total: int) -> list[str]:
    """Render opcode counts as Markdown table rows, most frequent first.

    Args:
        counts: Mapping of opcode name to occurrence count.
        total: Denominator used for the percentage column.

    Returns:
        One row per opcode in descending count order (ties keep the mapping's
        order), or a single placeholder row when *total* is not positive.
    """
    if total <= 0:
        return ["| n/a | 0 | n/a |"]
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    return [f"| `{op}` | {count} | {100.0 * count / total:.2f}% |" for op, count in ranked]
def validation_status(result: CompareResult) -> str:
    """Return the report label of a validation result.

    A result whose status is ``"done"`` is reported as ``PASS`` or ``FAIL``
    according to ``result.passed``; any other status is upper-cased as is.
    """
    if result.status != "done":
        return result.status.upper()
    return "PASS" if result.passed else "FAIL"


def skipped_validation(reason: str) -> CompareResult:
    """Build a non-passing ``CompareResult`` with status ``"skipped"`` and *reason* as error."""
    return CompareResult(passed=False, max_diffs={}, status="skipped", error=reason)


def failed_validation(error: BaseException | str) -> CompareResult:
    """Build a non-passing ``CompareResult`` with status ``"failed"``.

    *error* is used verbatim when it is a string and converted with
    ``exception_message`` otherwise.
    """
    message = error if isinstance(error, str) else exception_message(error)
    return CompareResult(passed=False, max_diffs={}, status="failed", error=message)


def skipped_perf(reason: str) -> dict[str, Any]:
    """Build the performance record of a run that was skipped for *reason*."""
    return {"skipped": True, "reason": reason}


def failed_perf(error: BaseException | str) -> dict[str, Any]:
    """Build the performance record of a failed run.

    *error* is used verbatim when it is a string and converted with
    ``exception_message`` otherwise.
    """
    message = error if isinstance(error, str) else exception_message(error)
    return {"error": message}


def perf_status(perf: dict[str, Any]) -> str:
    """Return ``SKIPPED``, ``FAILED`` or ``DONE`` for a performance record."""
    if perf.get("skipped"):
        return "SKIPPED"
    # Test for the key rather than its truthiness: failed_perf() always sets
    # "error", and an empty message must not be reported as a completed run.
    if "error" in perf:
        return "FAILED"
    return "DONE"


def perf_value(perf: dict[str, Any], key: str) -> Any:
    """Return ``perf[key]``, or the string ``"n/a"`` when the key is absent."""
    return perf.get(key, "n/a")


def empty_instruction_summary(reason: str | None = None, error: str | None = None) -> dict[str, Any]:
    """Build an instruction summary with zero counts.

    Args:
        reason: When given, the summary is marked ``skipped`` with this reason.
        error: When given, stored under ``"error"``.

    Returns:
        A dict with the same count and ranking keys as the instruction
        parsers produce, all empty or zero, plus the optional markers above.
    """
    result: dict[str, Any] = {
        "active_cores": 0,
        "total_instructions": 0,
        "op_counts": {},
        "top_cores_by_total": [],
        "top_cores_by_send": [],
        "top_cores_by_recv": [],
    }
    if reason is not None:
        result.update(skipped=True, reason=reason)
    if error is not None:
        result["error"] = error
    return result


def optional_path(path: Path | None) -> str | None:
    """Return ``str(path)``, or ``None`` when *path* is ``None``."""
    return None if path is None else str(path)


def record_failure(failures: list[dict[str, str]], stage: str, exc: BaseException | str) -> None:
    """Append a ``{"stage", "error"}`` entry to *failures* and print it.

    *exc* is used verbatim when it is a string and converted with
    ``exception_message`` otherwise; the entry is then reported through
    ``print_failure``.
    """
    message = exc if isinstance(exc, str) else exception_message(exc)
    failures.append({"stage": stage, "error": message})
    print_failure(stage, message)
def try_stage(
    failures: list[dict[str, str]],
    stage: str,
    func,
    *args,
    **kwargs,
):
    """Call ``func(*args, **kwargs)`` and return its result.

    Any ``Exception`` is passed to ``record_failure`` under *stage* and
    ``None`` is returned instead, so a ``None`` result cannot be told apart
    from a failure by the caller.
    """
    try:
        outcome = func(*args, **kwargs)
    except Exception as exc:
        record_failure(failures, stage, exc)
        outcome = None
    return outcome


def try_stage_success(
    failures: list[dict[str, str]],
    stage: str,
    func,
    *args,
    **kwargs,
) -> bool:
    """Call ``func(*args, **kwargs)`` and report whether it completed.

    Returns ``True`` when the call returns normally (its result is discarded)
    and ``False`` after passing a raised ``Exception`` to ``record_failure``.
    """
    try:
        func(*args, **kwargs)
    except Exception as exc:
        record_failure(failures, stage, exc)
        return False
    else:
        return True
") + lines.append(f"| {failure['stage']} | {error} |") + lines.append("") + + lines.extend( + [ + "## Semantic Validation", + "", + f"- Raptor via `pim-simulator`: `{validation_status(raptor_validation)}`", + f"- PIMCOMP via exported `pim-simulator`: `{validation_status(pimcomp_validation)}`", + ] + ) + if raptor_validation.error: + lines.append(f"- Raptor validation note: `{raptor_validation.error.splitlines()[0]}`") + if pimcomp_validation.error: + lines.append(f"- PIMCOMP validation note: `{pimcomp_validation.error.splitlines()[0]}`") + + lines.extend(["", "### Max Output Differences", ""]) + diff_names = sorted(set(raptor_validation.max_diffs) | set(pimcomp_validation.max_diffs)) + if diff_names: + lines.extend(["| Output | Raptor max diff | PIMCOMP max diff |", "|---|---:|---:|"]) + for name in diff_names: + lines.append( + f"| `{name}` | {raptor_validation.max_diffs.get(name, float('nan')):.6e} | " + f"{pimcomp_validation.max_diffs.get(name, float('nan')):.6e} |" + ) + else: + lines.append("No output differences are available because validation did not run or failed before comparison.") + + lines.extend( + [ + "", + "## pimsim-nn Performance", + "", + f"- Mode: `{pimsim_mode}`", + "", + ] + ) + if pimsim_mode == "throughput": + lines.extend( + [ + "| Compiler | Status | Throughput (samples/s) | Avg latency (ms) | Avg power (mW) | Avg energy (pJ/it) | Output count |", + "|---|---|---:|---:|---:|---:|---:|", + f"| Raptor | {perf_status(raptor_perf)} | {perf_value(raptor_perf, 'throughput')} | {perf_value(raptor_perf, 'average_latency_ms')} | " + f"{perf_value(raptor_perf, 'average_power_mw')} | {perf_value(raptor_perf, 'average_energy_pj')} | {perf_value(raptor_perf, 'output_count')} |", + f"| PIMCOMP | {perf_status(pimcomp_perf)} | {perf_value(pimcomp_perf, 'throughput')} | {perf_value(pimcomp_perf, 'average_latency_ms')} | " + f"{perf_value(pimcomp_perf, 'average_power_mw')} | {perf_value(pimcomp_perf, 'average_energy_pj')} | {perf_value(pimcomp_perf, 
def _markdown_table_cell(text: str) -> str:
    """Return *text* escaped so that it stays inside one Markdown table cell."""
    # A literal "|" would start a new column and a raw newline would end the
    # row, so escape the former and turn the latter into an HTML line break.
    return text.replace("|", "\\|").replace("\r\n", "\n").replace("\n", "<br>")


def _perf_row(label: str, perf: dict[str, Any], keys: tuple[str, ...]) -> str:
    """Return one performance-table row holding the status and the *keys* values of *perf*."""
    cells = " | ".join(str(perf_value(perf, key)) for key in keys)
    return f"| {label} | {perf_status(perf)} | {cells} |"


def _instruction_summary_row(label: str, summary: dict[str, Any]) -> str:
    """Return the instruction-summary table row for one compiler."""
    if summary.get("error"):
        status = "FAILED"
    elif summary.get("skipped"):
        status = "SKIPPED"
    else:
        status = "DONE"
    ops = summary.get("op_counts", {})
    return (
        f"| {label} | {status} | {summary.get('active_cores', 0)} | {summary.get('total_instructions', 0)} | "
        f"{ops.get('send', 0)} | {ops.get('recv', 0)} | {ops.get('mvmul', 0)} |"
    )


def write_report(
    report_path: Path,
    *,
    model_path: Path,
    hardware: dict[str, int],
    steps: list[StepRecord],
    failures: list[dict[str, str]],
    raptor_validation: CompareResult,
    pimcomp_validation: CompareResult,
    raptor_perf: dict[str, Any],
    pimcomp_perf: dict[str, Any],
    raptor_instr: dict[str, Any],
    pimcomp_instr: dict[str, Any],
    raptor_pass_timings: dict[str, float],
    pimsim_mode: str,
):
    """Write the Markdown comparison report to *report_path* (UTF-8).

    The report contains, in order: a header with model and hardware, an
    optional failures section, semantic-validation results with per-output
    maximum differences, the pimsim-nn performance table (its columns depend
    on *pimsim_mode*), instruction summaries with opcode distributions, step
    timings with details for steps that did not pass, and optionally the
    Raptor pass timings.

    Args:
        report_path: Destination file.
        model_path: Model path shown in the header.
        hardware: Hardware description; missing keys are shown as ``n/a``.
        steps: Executed steps (name, status, duration, return code, ...).
        failures: Recorded failures as ``{"stage", "error"}`` dicts.
        raptor_validation: Validation result of the Raptor flow.
        pimcomp_validation: Validation result of the PIMCOMP flow.
        raptor_perf: Performance record of the Raptor flow.
        pimcomp_perf: Performance record of the PIMCOMP flow.
        raptor_instr: Instruction summary of the Raptor flow.
        pimcomp_instr: Instruction summary of the PIMCOMP flow.
        raptor_pass_timings: Pass name to duration in seconds; may be empty.
        pimsim_mode: ``"throughput"`` selects the throughput table, anything
            else the latency table.
    """
    lines = [
        "# Raptor vs PIMCOMP Comparison Report",
        "",
        f"- Model: `{model_path}`",
        f"- Hardware: `{hardware.get('core_count', 'n/a')} cores`, `{hardware.get('crossbar_count', 'n/a')} xbars/core`, `{hardware.get('crossbar_size', 'n/a')}x{hardware.get('crossbar_size', 'n/a')}` crossbars, mesh `{hardware.get('mesh_rows', 'n/a')}x{hardware.get('mesh_cols', 'n/a')}`",
        "",
    ]

    if failures or any(step.status != "passed" for step in steps):
        lines.extend(
            [
                "## Failures / Skipped Work",
                "",
                "The script did not abort. The failed stage was recorded and any dependent stage was skipped when its inputs were not available.",
                "",
            ]
        )
        if failures:
            lines.extend(["| Stage | Error |", "|---|---|"])
            for failure in failures:
                # Error text is arbitrary tool output: keep it in one cell.
                stage = _markdown_table_cell(failure["stage"])
                error = _markdown_table_cell(failure["error"])
                lines.append(f"| {stage} | {error} |")
            lines.append("")

    lines.extend(
        [
            "## Semantic Validation",
            "",
            f"- Raptor via `pim-simulator`: `{validation_status(raptor_validation)}`",
            f"- PIMCOMP via exported `pim-simulator`: `{validation_status(pimcomp_validation)}`",
        ]
    )
    # Only the first line of an error is shown in the bullet notes.
    if raptor_validation.error:
        lines.append(f"- Raptor validation note: `{raptor_validation.error.splitlines()[0]}`")
    if pimcomp_validation.error:
        lines.append(f"- PIMCOMP validation note: `{pimcomp_validation.error.splitlines()[0]}`")

    lines.extend(["", "### Max Output Differences", ""])
    diff_names = sorted(set(raptor_validation.max_diffs) | set(pimcomp_validation.max_diffs))
    if diff_names:
        lines.extend(["| Output | Raptor max diff | PIMCOMP max diff |", "|---|---:|---:|"])
        for name in diff_names:
            # An output present on one side only is shown as nan on the other.
            lines.append(
                f"| `{name}` | {raptor_validation.max_diffs.get(name, float('nan')):.6e} | "
                f"{pimcomp_validation.max_diffs.get(name, float('nan')):.6e} |"
            )
    else:
        lines.append("No output differences are available because validation did not run or failed before comparison.")

    lines.extend(
        [
            "",
            "## pimsim-nn Performance",
            "",
            f"- Mode: `{pimsim_mode}`",
            "",
        ]
    )
    if pimsim_mode == "throughput":
        throughput_keys = (
            "throughput",
            "average_latency_ms",
            "average_power_mw",
            "average_energy_pj",
            "output_count",
        )
        lines.extend(
            [
                "| Compiler | Status | Throughput (samples/s) | Avg latency (ms) | Avg power (mW) | Avg energy (pJ/it) | Output count |",
                "|---|---|---:|---:|---:|---:|---:|",
                _perf_row("Raptor", raptor_perf, throughput_keys),
                _perf_row("PIMCOMP", pimcomp_perf, throughput_keys),
                "",
            ]
        )
    else:
        latency_keys = ("latency_ms", "average_power_mw", "average_energy_pj")
        lines.extend(
            [
                "| Compiler | Status | Latency (ms) | Avg power (mW) | Avg energy (pJ) |",
                "|---|---|---:|---:|---:|",
                _perf_row("Raptor", raptor_perf, latency_keys),
                _perf_row("PIMCOMP", pimcomp_perf, latency_keys),
                "",
            ]
        )
    if raptor_perf.get("reason") or raptor_perf.get("error"):
        lines.append(f"- Raptor pimsim-nn note: `{(raptor_perf.get('reason') or raptor_perf.get('error')).splitlines()[0]}`")
    if pimcomp_perf.get("reason") or pimcomp_perf.get("error"):
        lines.append(f"- PIMCOMP pimsim-nn note: `{(pimcomp_perf.get('reason') or pimcomp_perf.get('error')).splitlines()[0]}`")
    if lines[-1] != "":
        lines.append("")

    lines.extend(
        [
            "## Instruction Summary",
            "",
            "| Compiler | Status | Active cores | Total instructions | Sends | Receives | MVMUL |",
            "|---|---|---:|---:|---:|---:|---:|",
            _instruction_summary_row("Raptor", raptor_instr),
            _instruction_summary_row("PIMCOMP", pimcomp_instr),
            "",
            "### Raptor Op Distribution",
            "",
            "| Op | Count | Share |",
            "|---|---:|---:|",
            *format_op_table(raptor_instr.get("op_counts", {}), raptor_instr.get("total_instructions", 0)),
            "",
            "### PIMCOMP Op Distribution",
            "",
            "| Op | Count | Share |",
            "|---|---:|---:|",
            *format_op_table(pimcomp_instr.get("op_counts", {}), pimcomp_instr.get("total_instructions", 0)),
            "",
            "## Step Timings",
            "",
            "| Step | Status | Duration (s) | Return code |",
            "|---|---|---:|---:|",
        ]
    )
    for step in steps:
        lines.append(
            f"| {step.name} | {step.status.upper()} | {step.duration_sec:.3f} | "
            f"{step.returncode if step.returncode is not None else ''} |"
        )
    failed_steps = [step for step in steps if step.status != "passed"]
    if failed_steps:
        lines.extend(["", "### Failed Step Details", ""])
        for step in failed_steps:
            lines.extend(
                [
                    f"#### {step.name}",
                    "",
                    f"- Command: `{step.command}`",
                    f"- Error: `{step.error or 'n/a'}`",
                ]
            )
            if step.output_tail:
                lines.extend(["", "```text", step.output_tail, "```"])
            lines.append("")

    if raptor_pass_timings:
        lines.extend(["", "## Raptor Pass Timings", "", "| Pass | Duration (s) |", "|---|---:|"])
        for name, duration in raptor_pass_timings.items():
            lines.append(f"| {name} | {duration:.4f} |")
    report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
def main():
    """Run the Raptor vs PIMCOMP comparison and write both reports.

    Every stage is executed through ``try_stage`` / ``try_stage_success`` so
    that a failing stage is recorded instead of aborting the script; stages
    whose inputs are missing are skipped and recorded as well. The Markdown
    and JSON reports are always written to ``--out-dir``.

    Raises:
        SystemExit: With status 1 when ``--fail-on-error`` is given and any
            failure was recorded or any step did not pass.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=Path)
    parser.add_argument("--out-dir", required=True, type=Path)
    parser.add_argument("--raptor-path", default=REPO / "build_release/Release/bin/onnx-mlir", type=Path)
    parser.add_argument("--onnx-include-dir", default=REPO / "onnx-mlir/include", type=Path)
    parser.add_argument("--pimcomp-dir", default=REPO / "third_party/PIMCOMP-NN", type=Path)
    parser.add_argument("--pim-simulator-dir", default=REPO / "backend-simulators/pim/pim-simulator", type=Path)
    parser.add_argument("--pimsim-nn-build-dir", default=REPO / "backend-simulators/pim/pimsim-nn/build", type=Path)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--threshold", type=float, default=1e-3)
    parser.add_argument("--rtol", type=float, default=1e-5)
    parser.add_argument("--timeout-seconds", type=float, default=3600.0)
    parser.add_argument("--core-count", type=int)
    parser.add_argument("--crossbar-count", type=int)
    parser.add_argument("--crossbar-size", type=int)
    parser.add_argument("--mesh-rows", type=int)
    parser.add_argument("--mesh-cols", type=int)
    parser.add_argument("--pimsim-time-ms", type=int, default=1000)
    parser.add_argument("--pimsim-mode", choices=["latency", "throughput"], default="latency")
    parser.add_argument("--skip-pimsim-nn", action="store_true")
    parser.add_argument("--verbose-raptor-compile", action="store_true")
    parser.add_argument("--raptor-extra-arg", action="append", default=[])
    parser.add_argument(
        "--fail-on-error",
        action="store_true",
        help="Return a non-zero process status after writing the reports if any compilation/run stage failed.",
    )
    args = parser.parse_args()

    model_path = args.model.resolve()
    out_dir = args.out_dir.resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    failures: list[dict[str, str]] = []
    steps: list[StepRecord] = []
    # Placeholder hardware used until (and unless) the real one is loaded.
    hardware: dict[str, int] = {
        "mesh_rows": 0,
        "mesh_cols": 0,
        "crossbar_count": 0,
        "crossbar_size": 0,
        "core_count": 0,
    }
    inputs_desc: list[tuple[int, str, int, list[int]]] = []
    outputs_desc: list[tuple[int, str, int, list[int]]] = []
    arrays_in_order: list[np.ndarray] = []
    runtime_inputs: list[np.ndarray] = []

    network_mlir: Path | None = None
    runner_path: Path | None = None
    reference_dir: Path | None = None
    raptor_pim_dir: Path | None = None
    raptor_pass_timings: dict[str, float] = {}
    verification_info: Path | None = None
    simulation_info: Path | None = None
    pimcomp_export_dir: Path | None = None
    pimsim_config: Path | None = None

    # Every result starts out as "did not run" so the reports are complete
    # even when an early stage fails.
    raptor_validation = skipped_validation("Raptor validation did not run")
    pimcomp_validation = skipped_validation("PIMCOMP validation did not run")
    raptor_perf: dict[str, Any] = skipped_perf("pimsim-nn Raptor did not run")
    pimcomp_perf: dict[str, Any] = skipped_perf("pimsim-nn PIMCOMP did not run")
    raptor_instr: dict[str, Any] = empty_instruction_summary("Raptor instruction parsing did not run")
    pimcomp_instr: dict[str, Any] = empty_instruction_summary("PIMCOMP instruction parsing did not run")

    loaded_hardware = try_stage(failures, "Load hardware configuration", load_effective_hardware, args)
    if loaded_hardware is not None:
        hardware = loaded_hardware
    # Use .get so that a loaded configuration without "core_count" is treated
    # as unavailable hardware instead of raising before any report is written.
    has_hardware = (hardware.get("core_count") or 0) > 0

    model_io = try_stage(failures, "Load model inputs", load_model_inputs, model_path, args.seed)
    if model_io is not None:
        inputs_desc, outputs_desc, arrays_in_order, runtime_inputs = model_io

    expected_network_mlir = out_dir / "reference" / f"{model_path.stem}.onnx.mlir"
    expected_runner_path = out_dir / "runner" / "build" / "runner"

    reference_compile = try_stage(
        failures,
        "Compile reference",
        compile_reference,
        args,
        model_path,
        out_dir,
        steps,
    )
    if reference_compile is not None:
        network_mlir, _, runner_path = reference_compile
    else:
        # The reference compile failed; reuse whichever of its expected
        # artifacts already exist on disk.
        if expected_network_mlir.exists():
            network_mlir = expected_network_mlir
            print(f"\n[Continue] Reusing partial ONNX MLIR: {network_mlir}")
        if expected_runner_path.exists():
            runner_path = expected_runner_path
            print(f"\n[Continue] Reusing partial runner: {runner_path}")

    if runner_path is not None and runner_path.exists() and model_io is not None:
        generated_reference = try_stage(
            failures,
            "Run reference",
            generate_reference_outputs,
            runner_path,
            runner_path.parent,
            model_path,
            arrays_in_order,
            steps,
            args,
            out_dir,
        )
        if generated_reference is not None:
            reference_dir = generated_reference
    else:
        record_failure(
            failures,
            "Skip reference outputs",
            "Reference outputs were skipped because the native runner or model inputs are not available.",
        )

    if network_mlir is not None and network_mlir.exists() and has_hardware:
        compiled_raptor = try_stage(
            failures,
            "Compile Raptor PIM",
            compile_raptor_target,
            network_mlir,
            out_dir / "raptor",
            hardware,
            args,
            steps,
        )
        if compiled_raptor is not None:
            raptor_pim_dir, raptor_pass_timings = compiled_raptor
    else:
        record_failure(
            failures,
            "Skip Raptor PIM compile",
            "Raptor PIM compile was skipped because the ONNX MLIR or hardware configuration is not available.",
        )

    if raptor_pim_dir is not None:
        wrote_inputs = try_stage_success(
            failures,
            "Write Raptor inputs",
            write_inputs_to_memory_bin,
            raptor_pim_dir / "memory.bin",
            raptor_pim_dir / "config.json",
            runtime_inputs,
        )
        if wrote_inputs and reference_dir is not None and outputs_desc:
            validation = try_stage(
                failures,
                "Rust Validation Raptor",
                run_rust_validation,
                "Rust Validation Raptor",
                raptor_pim_dir,
                raptor_pim_dir / "config.json",
                outputs_desc,
                reference_dir,
                steps,
                args,
            )
            raptor_validation = validation if validation is not None else failed_validation("Raptor validation failed")
        elif reference_dir is None:
            raptor_validation = skipped_validation("Reference outputs are not available")
        elif not outputs_desc:
            raptor_validation = skipped_validation("Output descriptors are not available")
        else:
            raptor_validation = skipped_validation("Raptor input materialization failed")
    else:
        raptor_validation = skipped_validation("Raptor PIM compilation did not produce a PIM directory")

    compiled_pimcomp = try_stage(
        failures,
        "Compile PIMCOMP",
        compile_pimcomp,
        args,
        model_path,
        out_dir / "pimcomp",
        steps,
    )
    if compiled_pimcomp is not None:
        verification_info, simulation_info = compiled_pimcomp

    if verification_info is not None and simulation_info is not None and model_io is not None:
        exported = try_stage(
            failures,
            "Export PIMCOMP for Rust",
            export_pimcomp_for_rust,
            model_path,
            verification_info,
            simulation_info,
            runtime_inputs,
            out_dir / "pimcomp_exported",
        )
        if exported is not None:
            pimcomp_export_dir = exported
    elif verification_info is None or simulation_info is None:
        record_failure(
            failures,
            "Skip PIMCOMP Rust export",
            "PIMCOMP Rust export was skipped because PIMCOMP did not produce VerificationInfo.json and SimulationInfo.gz.",
        )
    else:
        record_failure(
            failures,
            "Skip PIMCOMP Rust export",
            "PIMCOMP Rust export was skipped because model inputs are not available.",
        )

    if pimcomp_export_dir is not None and reference_dir is not None and outputs_desc:
        validation = try_stage(
            failures,
            "Rust Validation PIMCOMP",
            run_rust_validation,
            "Rust Validation PIMCOMP",
            pimcomp_export_dir,
            pimcomp_export_dir / "config.json",
            outputs_desc,
            reference_dir,
            steps,
            args,
        )
        pimcomp_validation = validation if validation is not None else failed_validation("PIMCOMP validation failed")
    elif pimcomp_export_dir is None:
        pimcomp_validation = skipped_validation("PIMCOMP Rust export is not available")
    elif reference_dir is None:
        pimcomp_validation = skipped_validation("Reference outputs are not available")
    else:
        pimcomp_validation = skipped_validation("Output descriptors are not available")

    if has_hardware:
        written_config = try_stage(
            failures,
            "Write pimsim-nn config",
            write_pimsim_config,
            args,
            out_dir / "pimsim_config",
            hardware,
        )
        if written_config is not None:
            pimsim_config = written_config
    else:
        record_failure(
            failures,
            "Skip pimsim-nn config",
            "pimsim-nn config was skipped because the hardware configuration is not available.",
        )

    if args.skip_pimsim_nn:
        raptor_perf = skipped_perf("Skipped by --skip-pimsim-nn")
        pimcomp_perf = skipped_perf("Skipped by --skip-pimsim-nn")
    elif pimsim_config is None:
        raptor_perf = skipped_perf("pimsim-nn config is not available")
        pimcomp_perf = skipped_perf("pimsim-nn config is not available")
    else:
        if raptor_pim_dir is not None:
            perf = try_stage(
                failures,
                "pimsim-nn Raptor",
                run_pimsim_nn,
                "pimsim-nn Raptor",
                raptor_pim_dir,
                pimsim_config,
                False,
                steps,
                args,
            )
            raptor_perf = perf if perf is not None else failed_perf("pimsim-nn Raptor failed")
        else:
            raptor_perf = skipped_perf("Raptor PIM directory is not available")

        if simulation_info is not None:
            pimcomp_pimsim_dir = try_stage(
                failures,
                "Export PIMCOMP for pimsim-nn",
                export_pimcomp_for_pimsim_nn,
                simulation_info,
                out_dir / "pimcomp_pimsim_nn",
            )
            if pimcomp_pimsim_dir is not None:
                perf = try_stage(
                    failures,
                    "pimsim-nn PIMCOMP",
                    run_pimsim_nn,
                    "pimsim-nn PIMCOMP",
                    pimcomp_pimsim_dir,
                    pimsim_config,
                    False,
                    steps,
                    args,
                )
                pimcomp_perf = perf if perf is not None else failed_perf("pimsim-nn PIMCOMP failed")
            else:
                pimcomp_perf = failed_perf("PIMCOMP pimsim-nn export failed")
        else:
            pimcomp_perf = skipped_perf("PIMCOMP SimulationInfo.gz is not available")

    if raptor_pim_dir is not None and raptor_pim_dir.exists():
        parsed = try_stage(failures, "Parse Raptor instructions", parse_raptor_instructions, raptor_pim_dir)
        raptor_instr = parsed if parsed is not None else empty_instruction_summary(error="Failed to parse Raptor instructions")
    else:
        raptor_instr = empty_instruction_summary("Raptor PIM directory is not available")

    if simulation_info is not None and simulation_info.exists():
        parsed = try_stage(failures, "Parse PIMCOMP instructions", parse_pimcomp_instructions, simulation_info)
        pimcomp_instr = parsed if parsed is not None else empty_instruction_summary(error="Failed to parse PIMCOMP instructions")
    else:
        pimcomp_instr = empty_instruction_summary("PIMCOMP SimulationInfo.gz is not available")

    report_path = out_dir / "comparison_report.md"
    write_report(
        report_path,
        model_path=model_path,
        hardware=hardware,
        steps=steps,
        failures=failures,
        raptor_validation=raptor_validation,
        pimcomp_validation=pimcomp_validation,
        raptor_perf=raptor_perf,
        pimcomp_perf=pimcomp_perf,
        raptor_instr=raptor_instr,
        pimcomp_instr=pimcomp_instr,
        raptor_pass_timings=raptor_pass_timings,
        pimsim_mode=args.pimsim_mode,
    )

    json_report = {
        "model": str(model_path),
        "hardware": hardware,
        "pimsim_mode": args.pimsim_mode,
        "failures": failures,
        "steps": [asdict(step) for step in steps],
        "raptor_validation": asdict(raptor_validation),
        "pimcomp_validation": asdict(pimcomp_validation),
        "raptor_performance": raptor_perf,
        "pimcomp_performance": pimcomp_perf,
        "raptor_instruction_summary": raptor_instr,
        "pimcomp_instruction_summary": pimcomp_instr,
        "raptor_pass_timings": raptor_pass_timings,
        "paths": {
            "reference_outputs": optional_path(reference_dir),
            "raptor_pim": optional_path(raptor_pim_dir),
            "pimcomp_simulation_info": optional_path(simulation_info),
            "pimcomp_exported_pim": optional_path(pimcomp_export_dir),
            "pimsim_config": optional_path(pimsim_config),
            "report_markdown": str(report_path),
        },
    }
    json_path = out_dir / "comparison_report.json"
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(json_report, f, indent=2)
        f.write("\n")

    # Evaluated once; it drives both the summary line and the exit status.
    had_problems = bool(failures) or any(step.status != "passed" for step in steps)

    print("\n[Done]")
    print(f"  Report: {report_path}")
    print(f"  JSON:   {json_path}")
    if had_problems:
        print(f"  Completed with {len(failures)} recorded failure/skipped stage(s).")

    if args.fail_on_error and had_problems:
        raise SystemExit(1)


if __name__ == "__main__":
    main()