From 568fd905424ce108395619b72e2a8b157fb958d3 Mon Sep 17 00:00:00 2001 From: NiccoloN Date: Thu, 25 Jun 2026 18:57:12 +0200 Subject: [PATCH] cose --- .serena/project.yml | 134 --------- .../BatchCoreLoweringPatterns.cpp | 195 +++++++++++- .../SpatialToPim/CoreLoweringPatterns.cpp | 96 ++++++ src/PIM/Conversion/SpatialToPim/Patterns.cpp | 106 +++++++ .../MaterializeMergeSchedule.cpp | 277 ++++++++++++++++-- validation/gen_network_runner.py | 18 +- 6 files changed, 647 insertions(+), 179 deletions(-) delete mode 100644 .serena/project.yml diff --git a/.serena/project.yml b/.serena/project.yml deleted file mode 100644 index 06c4bd3..0000000 --- a/.serena/project.yml +++ /dev/null @@ -1,134 +0,0 @@ -# the name by which the project can be referenced within Serena -project_name: raptor - -# list of languages for which language servers are started; choose from: -# al angular ansible bash clojure -# cpp cpp_ccls crystal csharp csharp_omnisharp -# dart elixir elm erlang fortran -# fsharp go groovy haskell haxe -# hlsl html java json julia -# kotlin lean4 lua luau markdown -# matlab msl nix ocaml pascal -# perl php php_phpactor powershell python -# python_jedi python_ty r rego ruby -# ruby_solargraph rust scala scss solidity -# svelte swift systemverilog terraform toml -# typescript typescript_vts vue yaml zig -# (This list may be outdated. For the current list, see values of Language enum here: -# https://github.com/oraios/serena/blob/main/src/solidlsp/ls_config.py -# For some languages, there are alternative language servers, e.g. csharp_omnisharp, ruby_solargraph.) -# Note: -# - For C, use cpp -# - For JavaScript, use typescript -# - For Angular projects, use angular (subsumes typescript+html; requires `npm install` in the project root) -# - For Svelte projects, use svelte (subsumes typescript/javascript for .svelte projects; requires npm) -# - For SCSS / Sass / plain CSS, use scss (some-sass-language-server handles all three) -# - For Free Pascal/Lazarus, use pascal -# Special requirements: -# Some languages require additional setup/installations. -# See here for details: https://oraios.github.io/serena/01-about/020_programming-languages.html#language-servers -# When using multiple languages, the first language server that supports a given file will be used for that file. -# The first language is the default language and the respective language server will be used as a fallback. -# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored. -languages: -- cpp -- rust -- python - -# the encoding used by text files in the project -# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings -encoding: utf-8 - -# list of additional paths to ignore in this project. -# Same syntax as gitignore, so you can use * and **. -# Note: global ignored_paths from serena_config.yml are also applied additively. -ignored_paths: - -# list of mode names that are to be activated by default, overriding the setting in the global configuration. -# The full set of modes to be activated is base_modes (from global config) + default_modes + added_modes. -# If the setting is undefined/empty, the default_modes from the global configuration (serena_config.yml) apply. -# Otherwise, this overrides the setting from the global configuration (serena_config.yml). -# Therefore, you can set this to [] if you do not want the default modes defined in the global config to apply -# for this project. -# This setting can, in turn, be overridden by CLI parameters (--mode). -# See https://oraios.github.io/serena/02-usage/050_configuration.html#modes -default_modes: - -# list of mode names to be activated additionally for this project, e.g. ["query-projects"] -# The full set of modes to be activated is base_modes (from global config) + default_modes + added_modes. -# See https://oraios.github.io/serena/02-usage/050_configuration.html#modes -added_modes: - -# list of tool names to exclude. -# This extends the existing exclusions (e.g. from the global configuration) -# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html -excluded_tools: [] - -# list of tools to include that would otherwise be disabled (particularly optional tools that are disabled by default). -# This extends the existing inclusions (e.g. from the global configuration). -# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html -included_optional_tools: [] - -# fixed set of tools to use as the base tool set (if non-empty), replacing Serena's default set of tools. -# This cannot be combined with non-empty excluded_tools or included_optional_tools. -# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html -fixed_tools: [] - -# time budget (seconds) per tool call for the retrieval of additional symbol information -# such as docstrings or parameter information. -# This overrides the corresponding setting in the global configuration; see the documentation there. -# If null or missing, use the setting from the global configuration. -symbol_info_budget: - -# The language backend to use for this project. -# If not set, the global setting from serena_config.yml is used. -# Valid values: LSP, JetBrains -# Note: the backend is fixed at startup. If a project with a different backend -# is activated post-init, an error will be returned. -language_backend: - -# line ending convention to use when writing source files. -# Possible values: unset (use global setting), "lf", "crlf", or "native" (platform default) -# This does not affect Serena's own files (e.g. memories and configuration files), which always use native line endings. -line_ending: - -# list of regex patterns which, when matched, mark a memory entry as read‑only. -# Extends the list from the global configuration, merging the two lists. -read_only_memory_patterns: [] - -# list of regex patterns for memories to completely ignore. -# Matching memories will not appear in list_memories or activate_project output -# and cannot be accessed via read_memory or write_memory. -# To access ignored memory files, use the read_file tool on the raw file path. -# Extends the list from the global configuration, merging the two lists. -# Example: ["_archive/.*", "_episodes/.*"] -ignored_memory_patterns: [] - -# advanced configuration option allowing to configure language server-specific options. -# Maps the language key to the options. -# Have a look at the docstring of the constructors of the LS implementations within solidlsp (e.g., for C# or PHP) to see which options are available. -# No documentation on options means no options are available. -ls_specific_settings: {} - -# list of additional workspace folder paths for cross-package reference support (e.g. in monorepos). -# Paths can be absolute or relative to the project root. -# Each folder is registered as an LSP workspace folder, enabling language servers to discover -# symbols and references across package boundaries. -# Currently supported for: TypeScript. -# Example: -# additional_workspace_folders: -# - ../sibling-package -# - ../shared-lib -additional_workspace_folders: [] - -# whether the project is in read-only mode -# If set to true, all editing tools will be disabled and attempts to use them will result in an error -# Added on 2025-04-18 -read_only: false - -# whether to use project's .gitignore files to ignore files -ignore_all_files_in_gitignore: true - -# initial prompt for the project. It will always be given to the LLM upon activating the project -# (contrary to the memories, which are loaded on demand). -initial_prompt: '' diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index d262ea5..4a3b01b 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -5,6 +5,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" +#include #include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/SpatialToPim/SpatialToPimPass.hpp" @@ -138,26 +139,49 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult(); } +static SmallVector getStaticIndexAttrs(Builder& builder, ArrayRef values) { + SmallVector attrs; + attrs.reserve(values.size()); + for (int64_t value : values) + attrs.push_back(builder.getIndexAttr(value)); + return attrs; +} + +static SmallVector getUnitStrides(Builder& builder, int64_t rank) { + SmallVector strides; + strides.reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) + strides.push_back(builder.getIndexAttr(1)); + return strides; +} + static Value createHostTargetOffset(IRRewriter& rewriter, - tensor::ParallelInsertSliceOp insertSlice, + Location loc, ShapedType destinationType, + ArrayRef mixedOffsets, + ArrayRef additionalOffsets, IRMapping& mapper) { int64_t elementBytes = static_cast(getElementTypeSizeInBytes(destinationType.getElementType())); SmallVector strides = computeRowMajorStrides(destinationType.getShape()); Value totalOffset; - Location loc = insertSlice.getLoc(); - for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) { + for (auto [dim, offset] : llvm::enumerate(mixedOffsets)) { int64_t scale = strides[dim] * elementBytes; Value scaledOffset; if (auto attr = dyn_cast(offset)) { auto intAttr = dyn_cast(attr); assert(intAttr && "expected integer offset attribute"); - scaledOffset = - getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale); - } - else { + scaledOffset = getOrCreateIndexConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + (intAttr.getInt() + additionalOffsets[dim]) * scale); + } else { scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast(offset)), scale); + if (additionalOffsets[dim] != 0) { + Value staticOffset = getOrCreateIndexConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + additionalOffsets[dim] * scale); + scaledOffset = arith::AddIOp::create(rewriter, loc, scaledOffset, staticOffset).getResult(); + } } totalOffset = @@ -169,6 +193,127 @@ static Value createHostTargetOffset(IRRewriter& rewriter, return totalOffset; } +static Value createHostTargetOffset(IRRewriter& rewriter, + tensor::ParallelInsertSliceOp insertSlice, + ShapedType destinationType, + IRMapping& mapper) { + SmallVector zeroOffsets(destinationType.getRank(), 0); + return createHostTargetOffset(rewriter, + insertSlice.getLoc(), + destinationType, + insertSlice.getMixedOffsets(), + zeroOffsets, + mapper); +} + +static SmallVector buildFragmentOffsets(IRRewriter& rewriter, + Location loc, + ArrayRef baseOffsets, + ArrayRef fragmentOffsets, + IRMapping& mapper) { + SmallVector combined; + combined.reserve(fragmentOffsets.size()); + for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) { + if (auto attr = dyn_cast(baseOffset)) { + int64_t base = cast(attr).getInt(); + combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim])); + continue; + } + + Value dynamicBase = mapper.lookupOrDefault(cast(baseOffset)); + if (fragmentOffsets[dim] == 0) { + combined.push_back(dynamicBase); + continue; + } + + Value staticOffset = + getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]); + combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult()); + } + return combined; +} + +static FailureOr lowerFragmentAssemblyHostCopies(IRRewriter& rewriter, + spatial::SpatReconciliatorOp reconciliator, + Value hostTarget, + ArrayRef baseOffsets, + IRMapping& mapper) { + auto hostTargetType = dyn_cast(hostTarget.getType()); + auto resultType = dyn_cast(reconciliator.getOutput().getType()); + if (!hostTargetType || !resultType || !resultType.hasStaticShape()) + return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results"); + + std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); + std::optional> fragmentStridesAttr = reconciliator.getFragmentStrides(); + if (!operandIndicesAttr || !fragmentStridesAttr) + return reconciliator.emitOpError( + "fragment assembly lowering requires explicit operand indices and unit strides"); + + ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); + ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatStrides = *fragmentStridesAttr; + int64_t rank = resultType.getRank(); + + SmallVector fragmentOperands {reconciliator.getInput()}; + llvm::append_range(fragmentOperands, reconciliator.getFragments()); + + DenseMap packedFragmentOrdinals; + for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { + int64_t operandIndex = operandIndices[fragmentIndex]; + if (operandIndex < 0 || operandIndex >= static_cast(fragmentOperands.size())) + return reconciliator.emitOpError("fragment assembly operand index is out of range"); + + SmallVector fragmentOffsets; + int64_t fragmentElements = 1; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + if (flatStrides[flatIndex] != 1) + return reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); + fragmentOffsets.push_back(flatOffsets[flatIndex]); + fragmentElements *= flatSizes[flatIndex]; + } + + Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]); + auto sourceType = dyn_cast(source.getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + + int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++; + SmallVector fragmentShape; + fragmentShape.reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) + fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]); + + Value fragment = source; + if (llvm::to_vector(sourceType.getShape()) != fragmentShape) { + SmallVector extractOffsets(rank, 0); + extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0]; + fragment = tensor::ExtractSliceOp::create(rewriter, + reconciliator.getLoc(), + source, + getStaticIndexAttrs(rewriter, extractOffsets), + getStaticIndexAttrs(rewriter, fragmentShape), + getUnitStrides(rewriter, rank)); + } + + hostTarget = tensor::InsertSliceOp::create(rewriter, + reconciliator.getLoc(), + fragment, + hostTarget, + buildFragmentOffsets(rewriter, + reconciliator.getLoc(), + baseOffsets, + fragmentOffsets, + mapper), + getStaticIndexAttrs(rewriter, fragmentShape), + getUnitStrides(rewriter, rank)) + .getResult(); + } + + return hostTarget; +} + } // namespace LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp, @@ -207,8 +352,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul SmallVector returnOperandIndices; if (computeBatchOp.getNumResults() != 0) { - returnOperandIndices.resize(computeBatchOp.getNumResults()); + returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits::max()); for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) { + if (result.use_empty()) + continue; FailureOr returnOperandIndex = getDirectReturnOperandIndex(cast(result)); if (failed(returnOperandIndex)) return computeBatchOp.emitOpError( @@ -271,6 +418,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul if (isa(op)) continue; + if (auto reconciliator = dyn_cast(op)) { + std::optional modeAttr = reconciliator.getMode(); + if (modeAttr && *modeAttr == "fragment_assembly") { + for (Operation* user : reconciliator.getOutput().getUsers()) { + if (!isa(user)) + return reconciliator.emitOpError( + "fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users"); + } + continue; + } + } + if (auto parallelOp = dyn_cast(op)) { auto firstOutputArg = computeBatchOp.getOutputArgument(0); if (!firstOutputArg) @@ -287,10 +446,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber(); if (resultIndex >= returnOperandIndices.size()) return insertSlice.emitOpError("result index out of range while lowering host batch output"); + if (returnOperandIndices[resultIndex] == std::numeric_limits::max()) + continue; - Value mappedSource = mapper.lookup(insertSlice.getSource()); Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc()); auto hostTargetType = cast(hostTarget.getType()); + if (auto reconciliator = + insertSlice.getSource().getDefiningOp()) { + std::optional modeAttr = reconciliator.getMode(); + if (modeAttr && *modeAttr == "fragment_assembly") { + FailureOr updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter, + reconciliator, + hostTarget, + insertSlice.getMixedOffsets(), + mapper); + if (failed(updatedHostTarget)) + return failure(); + hostOutputTensors[resultIndex] = *updatedHostTarget; + continue; + } + } + + Value mappedSource = mapper.lookup(insertSlice.getSource()); Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper); Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0); auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource); diff --git a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp index 696b0bc..366a259 100644 --- a/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.cpp @@ -30,6 +30,91 @@ static bool isChannelUseChainOp(Operation* op) { pim::PimTransposeOp>(op); } +static Value createStaticHostTargetOffset(IRRewriter& rewriter, + Location loc, + ShapedType destinationType, + ArrayRef fragmentOffsets) { + int64_t elementBytes = static_cast(getElementTypeSizeInBytes(destinationType.getElementType())); + SmallVector strides = computeRowMajorStrides(destinationType.getShape()); + + int64_t byteOffset = 0; + for (auto [dim, offset] : llvm::enumerate(fragmentOffsets)) + byteOffset += offset * strides[dim] * elementBytes; + return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), byteOffset); +} + +static FailureOr lowerFragmentAssemblyReconciliator(IRRewriter& rewriter, + spatial::SpatReconciliatorOp reconciliator, + IRMapping& mapping) { + auto resultType = dyn_cast(reconciliator.getOutput().getType()); + if (!resultType || !resultType.hasStaticShape()) + return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result"); + + std::optional modeAttr = reconciliator.getMode(); + std::optional> operandIndicesAttr = reconciliator.getFragmentOperandIndices(); + std::optional> fragmentStridesAttr = reconciliator.getFragmentStrides(); + if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr) + return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata"); + + ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef flatOffsets = reconciliator.getFragmentOffsets(); + ArrayRef flatSizes = reconciliator.getFragmentSizes(); + ArrayRef flatStrides = *fragmentStridesAttr; + int64_t rank = resultType.getRank(); + + SmallVector fragmentOperands {reconciliator.getInput()}; + llvm::append_range(fragmentOperands, reconciliator.getFragments()); + + Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType); + DenseMap packedFragmentOrdinals; + for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { + int64_t operandIndex = operandIndices[fragmentIndex]; + if (operandIndex < 0 || operandIndex >= static_cast(fragmentOperands.size())) + return reconciliator.emitOpError("fragment assembly operand index is out of range"); + + SmallVector fragmentOffsets; + int64_t fragmentElements = 1; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + if (flatStrides[flatIndex] != 1) + return reconciliator.emitOpError("fragment assembly lowering only supports unit strides"); + fragmentOffsets.push_back(flatOffsets[flatIndex]); + fragmentElements *= flatSizes[flatIndex]; + } + + Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]); + auto sourceType = dyn_cast(source.getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + + int64_t fragmentBytes = + fragmentElements * static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); + auto sizeAttr = pim::getCheckedI32Attr(rewriter, + reconciliator.getOperation(), + fragmentBytes, + "fragment assembly host copy size"); + if (failed(sizeAttr)) + return failure(); + + int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++; + Value hostTargetOffset = createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets); + Value deviceSourceOffset = getOrCreateIndexConstant(rewriter, + rewriter.getInsertionBlock()->getParentOp(), + packedFragmentOrdinal * fragmentBytes); + currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter, + reconciliator.getLoc(), + currentOutput.getType(), + hostTargetOffset, + deviceSourceOffset, + currentOutput, + source, + *sizeAttr) + .getOutput(); + } + + return currentOutput; +} + static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) { for (Value operand : op->getOperands()) { @@ -131,6 +216,17 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule mapping.map(*weightArg, weight); } for (Operation& op : block.without_terminator()) { + if (auto reconciliator = dyn_cast(op)) { + std::optional modeAttr = reconciliator.getMode(); + if (modeAttr && *modeAttr == "fragment_assembly") { + auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping); + if (failed(lowered)) + return false; + mapping.map(reconciliator.getOutput(), *lowered); + continue; + } + } + cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder); Operation* clonedOp = rewriter.clone(op, mapping); for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults())) diff --git a/src/PIM/Conversion/SpatialToPim/Patterns.cpp b/src/PIM/Conversion/SpatialToPim/Patterns.cpp index e0452e7..d5a57a1 100644 --- a/src/PIM/Conversion/SpatialToPim/Patterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/Patterns.cpp @@ -1,6 +1,10 @@ +#include "mlir/Transforms/DialectConversion.h" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" +#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" +#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" using namespace mlir; @@ -11,6 +15,107 @@ namespace raptor { } // namespace raptor +static SmallVector getStaticIndexAttrs(Builder& builder, ArrayRef values) { + SmallVector attrs; + attrs.reserve(values.size()); + for (int64_t value : values) + attrs.push_back(builder.getIndexAttr(value)); + return attrs; +} + +static SmallVector getUnitStrides(Builder& builder, int64_t rank) { + SmallVector strides; + strides.reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) + strides.push_back(builder.getIndexAttr(1)); + return strides; +} + +struct LowerFragmentAssemblyReconciliatorPattern + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op, + OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + std::optional modeAttr = op.getMode(); + if (!modeAttr || *modeAttr != "fragment_assembly") + return failure(); + + auto resultType = dyn_cast(op.getOutput().getType()); + if (!resultType || !resultType.hasStaticShape()) + return op.emitOpError("fragment assembly lowering requires a static ranked tensor result"); + + std::optional> operandIndicesAttr = op.getFragmentOperandIndices(); + std::optional> fragmentStridesAttr = op.getFragmentStrides(); + if (!operandIndicesAttr || !fragmentStridesAttr) + return op.emitOpError("fragment assembly lowering requires explicit fragment metadata"); + + ArrayRef operandIndices = *operandIndicesAttr; + ArrayRef flatOffsets = op.getFragmentOffsets(); + ArrayRef flatSizes = op.getFragmentSizes(); + ArrayRef flatStrides = *fragmentStridesAttr; + int64_t rank = resultType.getRank(); + + SmallVector fragmentOperands {adaptor.getInput()}; + llvm::append_range(fragmentOperands, adaptor.getFragments()); + + Value currentOutput = + tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult(); + DenseMap packedFragmentOrdinals; + for (int64_t fragmentIndex = 0; fragmentIndex < static_cast(operandIndices.size()); ++fragmentIndex) { + int64_t operandIndex = operandIndices[fragmentIndex]; + if (operandIndex < 0 || operandIndex >= static_cast(fragmentOperands.size())) + return op.emitOpError("fragment assembly operand index is out of range"); + + SmallVector fragmentOffsets; + int64_t fragmentElements = 1; + for (int64_t dim = 0; dim < rank; ++dim) { + int64_t flatIndex = fragmentIndex * rank + dim; + if (flatStrides[flatIndex] != 1) + return op.emitOpError("fragment assembly lowering only supports unit strides"); + fragmentOffsets.push_back(flatOffsets[flatIndex]); + fragmentElements *= flatSizes[flatIndex]; + } + + Value source = fragmentOperands[operandIndex]; + auto sourceType = dyn_cast(source.getType()); + if (!sourceType || !sourceType.hasStaticShape()) + return op.emitOpError("fragment assembly lowering requires static ranked tensor operands"); + + int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++; + SmallVector fragmentShape; + fragmentShape.reserve(rank); + for (int64_t dim = 0; dim < rank; ++dim) + fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]); + + Value fragment = source; + if (llvm::to_vector(sourceType.getShape()) != fragmentShape) { + SmallVector extractOffsets(rank, 0); + extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0]; + fragment = tensor::ExtractSliceOp::create(rewriter, + op.getLoc(), + source, + getStaticIndexAttrs(rewriter, extractOffsets), + getStaticIndexAttrs(rewriter, fragmentShape), + getUnitStrides(rewriter, rank)); + } + + currentOutput = tensor::InsertSliceOp::create(rewriter, + op.getLoc(), + fragment, + currentOutput, + getStaticIndexAttrs(rewriter, fragmentOffsets), + getStaticIndexAttrs(rewriter, fragmentShape), + getUnitStrides(rewriter, rank)) + .getResult(); + } + + rewriter.replaceOp(op, currentOutput); + return success(); + } +}; + void populateInitialPatterns(RewritePatternSet& patterns) { raptor::populateWithGenerated(patterns); populateTransposeLoweringPatterns(patterns); @@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) { void populateCoreBodyPatterns(RewritePatternSet& patterns) { raptor::populateWithGenerated(patterns); populateTransposeLoweringPatterns(patterns); + patterns.add(patterns.getContext()); } } // namespace onnx_mlir diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index 21089bb..dffb185 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -306,10 +306,12 @@ struct ProjectedExtractReplacement { struct PendingProjectedHostOutputFragment { Value originalOutput; ClassId sourceClass = 0; + ProducerKey producerKey; Value operand; RankedTensorType operandType; RankedTensorType fragmentType; int64_t packedFragmentIndex = -1; + int64_t currentLane = -1; SmallVector offsets; SmallVector sizes; SmallVector strides; @@ -1137,6 +1139,59 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) { return success(); } +void setInsertionPointForNewMaterializedOp(MaterializerState& state) { + Block& funcBlock = state.func.getBody().front(); + for (Operation& op : funcBlock) { + if (state.oldComputeOps.contains(&op)) { + state.rewriter.setInsertionPoint(&op); + return; + } + } + state.rewriter.setInsertionPointToEnd(&funcBlock); +} + +FailureOr createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) { + DenseSet usedCpus; + for (const auto& [cpu, _] : state.cpuToClass) + usedCpus.insert(cpu); + + CpuId assemblyCpu = 0; + while (usedCpus.contains(assemblyCpu)) + ++assemblyCpu; + + setInsertionPointForNewMaterializedOp(state); + + auto resultType = dyn_cast(originalOutput.getType()); + if (!resultType || !resultType.hasStaticShape()) + return state.func.emitError("projected host assembly class requires a static ranked tensor output"); + + auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {}); + compute.getProperties().setOperandSegmentSizes({0, 0}); + auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id"); + if (failed(coreIdAttr)) + return failure(); + compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr); + + Block* body = state.rewriter.createBlock(&compute.getBody()); + state.rewriter.setInsertionPointToEnd(body); + Value placeholder = + tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult(); + SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder}); + state.rewriter.setInsertionPointAfter(compute.getOperation()); + + MaterializedClass materializedClass; + materializedClass.id = state.classes.size(); + materializedClass.cpus.push_back(assemblyCpu); + materializedClass.op = compute.getOperation(); + materializedClass.body = body; + materializedClass.hostOutputToResultIndex[originalOutput] = 0; + materializedClass.hostOutputs.push_back(originalOutput); + state.cpuToClass[assemblyCpu] = materializedClass.id; + state.hostOutputOwners[originalOutput] = materializedClass.id; + state.classes.push_back(std::move(materializedClass)); + return state.classes.back().id; +} + BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) { auto it = materializedClass.weightArgs.find(weight); if (it != materializedClass.weightArgs.end()) @@ -1897,6 +1952,14 @@ FailureOr> buildProjectedFragmentOffsetsInClass(Mat return fragmentOffsets; } +SmallVector getStaticIndexAttrs(Builder& builder, ArrayRef values) { + SmallVector attrs; + attrs.reserve(values.size()); + for (int64_t value : values) + attrs.push_back(builder.getIndexAttr(value)); + return attrs; +} + Value createDim0InsertSlice( MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) { auto fragmentType = cast(fragment.getType()); @@ -3639,6 +3702,9 @@ LogicalResult appendSend(MaterializerState& state, if (sourceClass.isBatch) { state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + if (messages.size() != sourceClass.cpus.size()) + return sourceClass.op->emitError("batch send expects exactly one message per materialized lane") + << " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size(); Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc); @@ -3686,6 +3752,11 @@ Value appendReceive( state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); if (targetClass.isBatch) { + if (messages.size() != targetClass.cpus.size()) { + targetClass.op->emitOpError("batch receive expects exactly one message per materialized lane") + << " messageCount=" << messages.size() << " laneCount=" << targetClass.cpus.size(); + return Value(); + } Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc); Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc); Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc); @@ -5481,10 +5552,12 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { originalOutput, sourceClass.id, + ProducerKey {peer, resultIndex}, packed, cast(packed.getType()), fragmentType, static_cast(runIndex), + static_cast(runIndex), SmallVector(*offsets), SmallVector(*sizes), SmallVector(*strides), @@ -5572,10 +5645,12 @@ FailureOr recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment { originalOutput, sourceClass.id, + key, packed, packedType, fragmentType, operandIsDim0Packed ? static_cast(fragmentIndex) : -1, + static_cast(fragmentIndex), SmallVector(*offsets), SmallVector(*sizes), SmallVector(*strides), @@ -5611,16 +5686,6 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { } MaterializedClass* ownerClass = &state.classes[ownerIt->second]; - if (ownerClass->isBatch) { - auto scalarOwnerIt = llvm::find_if(state.classes, [](const MaterializedClass& candidate) { - return !candidate.isBatch; - }); - if (scalarOwnerIt == state.classes.end()) - return ownerClass->op->emitError( - "projected host output finalization requires a scalar assembly class when the preferred host owner is batch"); - ownerClass = &*scalarOwnerIt; - state.hostOutputOwners[originalOutput] = ownerClass->id; - } auto resultType = dyn_cast(originalOutput.getType()); if (!resultType || !resultType.hasStaticShape()) @@ -5646,6 +5711,119 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { if (allFromSameSourceClass) { ownerClass = &state.classes[fragments.front()->sourceClass]; state.hostOutputOwners[originalOutput] = ownerClass->id; + } else { + if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput)) + goto owner_selected; + FailureOr createdOwner = + createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc); + if (failed(createdOwner)) + return failure(); + ownerClass = &state.classes[*createdOwner]; + } +owner_selected: + + if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) { + auto sourceBatch = dyn_cast(fragments.front()->producerKey.instance.op); + auto batch = dyn_cast(ownerClass->op); + auto inParallelOp = dyn_cast_or_null(ownerClass->body->getTerminator()); + auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput); + if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end()) + return ownerClass->op->emitError("missing batch host assembly state for projected host output"); + FailureOr sourceProjection = + getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex); + std::optional sourceLaneArg = sourceBatch.getLaneArgument(); + if (failed(sourceProjection) || !sourceLaneArg) + return ownerClass->op->emitError( + "direct batch host output assembly requires the source batch projection metadata"); + + auto outputArg = batch.getOutputArgument(resultIt->second); + auto laneArg = batch.getLaneArgument(); + if (!outputArg || !laneArg) + return ownerClass->op->emitError("missing compute_batch output block argument for projected host output"); + + if (fragments.size() != ownerClass->cpus.size()) + return ownerClass->op->emitError( + "direct batch host output assembly expects exactly one fragment per materialized lane"); + + SmallVector fragmentsByLane(ownerClass->cpus.size(), nullptr); + for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { + int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane; + if (currentLane < 0 || currentLane >= static_cast(fragmentsByLane.size())) + return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds"); + if (fragmentsByLane[currentLane]) + return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane"); + fragmentsByLane[currentLane] = fragmentRecord; + } + + if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; })) + return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes"); + + FailureOr> firstSizes = + evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane); + FailureOr> firstStrides = + evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane); + if (failed(firstSizes) || failed(firstStrides)) + return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape"); + SmallVector referenceSizes(*firstSizes); + SmallVector referenceStrides(*firstStrides); + Value laneOperand; + for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) { + FailureOr> fragmentSizes = + evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane); + FailureOr> fragmentStrides = + evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane); + if (failed(fragmentSizes) || failed(fragmentStrides)) + return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape"); + if (SmallVector(*fragmentSizes) != referenceSizes + || SmallVector(*fragmentStrides) != referenceStrides) + return ownerClass->op->emitError( + "direct batch host output assembly expects a uniform fragment shape and strides"); + + MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; + Value operand; + if (std::optional availableValue = + state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) { + operand = *availableValue; + } else { + operand = fragmentRecord->operand; + } + if (!isValueLegalInMaterializedClassBody(operand, *ownerClass)) + return ownerClass->op->emitError( + "projected batch host output assembly requires source-local fragment operands"); + if (laneOperand && laneOperand != operand) + return ownerClass->op->emitError( + "direct batch host output assembly expects one shared lane-local fragment producer"); + laneOperand = operand; + } + + SmallVector mixedOffsets; + mixedOffsets.reserve(referenceSizes.size()); + for (size_t dim = 0; dim < referenceSizes.size(); ++dim) { + SmallVector offsetsByLane; + offsetsByLane.reserve(fragmentsByLane.size()); + for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) { + FailureOr> fragmentOffsets = + evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane); + if (failed(fragmentOffsets)) + return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets"); + offsetsByLane.push_back((*fragmentOffsets)[dim]); + } + mixedOffsets.push_back(allEqual(offsetsByLane) + ? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front())) + : OpFoldResult(createLaneIndexedIndexValue( + state, *ownerClass, ArrayRef(offsetsByLane), fragments.front()->loc))); + } + + state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second); + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + tensor::ParallelInsertSliceOp::create(state.rewriter, + fragments.front()->loc, + laneOperand, + *outputArg, + mixedOffsets, + getStaticIndexAttrs(state.rewriter, referenceSizes), + getStaticIndexAttrs(state.rewriter, referenceStrides)); + continue; } state.rewriter.setInsertionPoint(ownerClass->body->getTerminator()); @@ -5656,28 +5834,73 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) { SmallVector flatSizes; SmallVector flatStrides; DenseMap operandIndicesByValue; + DenseSet emittedBatchForwarding; for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) { - Value operand = fragmentRecord->operand; MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass]; + Value operand; + + if (std::optional availableValue = + state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) { + operand = *availableValue; + } else if (fragmentRecord->sourceClass == sourceClass.id) { + operand = fragmentRecord->operand; + } else { + return sourceClass.op->emitError( + "projected host output fragment assembly is missing source-visible fragment operands before finalization"); + } if (fragmentRecord->sourceClass != ownerClass->id) { - if (sourceClass.isBatch || ownerClass->isBatch) - return sourceClass.op->emitError( - "projected host output fragment assembly requires scalarized cross-class operands before finalization"); - MessageVector messages; - auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, - sourceClass.cpus.front(), - "projected host output source core id"); - auto checkedTargetCpu = getCheckedCoreId(ownerClass->op, - ownerClass->cpus.front(), - "projected host output target core id"); - if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) - return failure(); - messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); - if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc))) - return failure(); - operand = appendReceive(state, *ownerClass, fragmentRecord->operandType, messages, fragmentRecord->loc); + if (sourceClass.isBatch && !ownerClass->isBatch) { + if (!emittedBatchForwarding.insert(sourceClass.id).second) { + std::optional localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id); + if (!localized) + return ownerClass->op->emitError( + "projected host output fragment assembly is missing forwarded batch fragments"); + operand = *localized; + } else { + SmallVector forwardedKeys; + forwardedKeys.reserve(sourceClass.cpus.size()); + Value forwardedPayload = fragmentRecord->operand; + for (PendingProjectedHostOutputFragment* candidate : fragments) { + if (candidate->sourceClass != sourceClass.id) + continue; + if (candidate->operand != forwardedPayload) + return ownerClass->op->emitError( + "projected host output batch forwarding expects one shared batch payload per source class"); + forwardedKeys.push_back(candidate->producerKey); + } + llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) { + return lhs.instance.laneStart < rhs.instance.laneStart; + }); + if (failed(emitClassToClassCommunication( + state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc))) + return failure(); + std::optional localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id); + if (!localized) + return ownerClass->op->emitError( + "projected host output fragment assembly failed to recover forwarded batch fragment"); + operand = *localized; + } + } else { + MessageVector messages; + auto checkedSourceCpu = getCheckedCoreId(sourceClass.op, + sourceClass.cpus.front(), + "projected host output source core id"); + auto checkedTargetCpu = getCheckedCoreId(ownerClass->op, + ownerClass->cpus.front(), + "projected host output target core id"); + if (failed(checkedSourceCpu) || failed(checkedTargetCpu)) + return failure(); + messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu); + if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc))) + return failure(); + operand = appendReceive(state, + *ownerClass, + cast(operand.getType()), + messages, + fragmentRecord->loc); + } } else if (!ownerClass->isBatch) { FailureOr localOperand = materializeTensorValueForMaterializedClassUse( state, diff --git a/validation/gen_network_runner.py b/validation/gen_network_runner.py index 7f0b731..3706c15 100644 --- a/validation/gen_network_runner.py +++ b/validation/gen_network_runner.py @@ -6,15 +6,15 @@ from onnx import TensorProto # ONNX dtype -> (ctype, printf, ONNX_TYPE_*) DTYPES = { - TensorProto.FLOAT: ("float", "%g", "ONNX_TYPE_FLOAT"), - TensorProto.DOUBLE: ("double", "%g", "ONNX_TYPE_DOUBLE"), - TensorProto.INT64: ("int64_t", "%lld","ONNX_TYPE_INT64"), - TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"), - TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"), - TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"), - TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"), # stored as byte - TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"), # raw 16-bit - TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"), + TensorProto.FLOAT: ("float", "%.9g", "ONNX_TYPE_FLOAT"), + TensorProto.DOUBLE: ("double", "%.17g", "ONNX_TYPE_DOUBLE"), + TensorProto.INT64: ("int64_t", "%lld", "ONNX_TYPE_INT64"), + TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"), + TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"), + TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"), + TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"), + TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"), + TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"), } def esc(s): return s.replace("\\","\\\\").replace('"','\\"')