This commit is contained in:
@@ -1,134 +0,0 @@
|
|||||||
# the name by which the project can be referenced within Serena
|
|
||||||
project_name: raptor
|
|
||||||
|
|
||||||
# list of languages for which language servers are started; choose from:
|
|
||||||
# al angular ansible bash clojure
|
|
||||||
# cpp cpp_ccls crystal csharp csharp_omnisharp
|
|
||||||
# dart elixir elm erlang fortran
|
|
||||||
# fsharp go groovy haskell haxe
|
|
||||||
# hlsl html java json julia
|
|
||||||
# kotlin lean4 lua luau markdown
|
|
||||||
# matlab msl nix ocaml pascal
|
|
||||||
# perl php php_phpactor powershell python
|
|
||||||
# python_jedi python_ty r rego ruby
|
|
||||||
# ruby_solargraph rust scala scss solidity
|
|
||||||
# svelte swift systemverilog terraform toml
|
|
||||||
# typescript typescript_vts vue yaml zig
|
|
||||||
# (This list may be outdated. For the current list, see values of Language enum here:
|
|
||||||
# https://github.com/oraios/serena/blob/main/src/solidlsp/ls_config.py
|
|
||||||
# For some languages, there are alternative language servers, e.g. csharp_omnisharp, ruby_solargraph.)
|
|
||||||
# Note:
|
|
||||||
# - For C, use cpp
|
|
||||||
# - For JavaScript, use typescript
|
|
||||||
# - For Angular projects, use angular (subsumes typescript+html; requires `npm install` in the project root)
|
|
||||||
# - For Svelte projects, use svelte (subsumes typescript/javascript for .svelte projects; requires npm)
|
|
||||||
# - For SCSS / Sass / plain CSS, use scss (some-sass-language-server handles all three)
|
|
||||||
# - For Free Pascal/Lazarus, use pascal
|
|
||||||
# Special requirements:
|
|
||||||
# Some languages require additional setup/installations.
|
|
||||||
# See here for details: https://oraios.github.io/serena/01-about/020_programming-languages.html#language-servers
|
|
||||||
# When using multiple languages, the first language server that supports a given file will be used for that file.
|
|
||||||
# The first language is the default language and the respective language server will be used as a fallback.
|
|
||||||
# Note that when using the JetBrains backend, language servers are not used and this list is correspondingly ignored.
|
|
||||||
languages:
|
|
||||||
- cpp
|
|
||||||
- rust
|
|
||||||
- python
|
|
||||||
|
|
||||||
# the encoding used by text files in the project
|
|
||||||
# For a list of possible encodings, see https://docs.python.org/3.11/library/codecs.html#standard-encodings
|
|
||||||
encoding: utf-8
|
|
||||||
|
|
||||||
# list of additional paths to ignore in this project.
|
|
||||||
# Same syntax as gitignore, so you can use * and **.
|
|
||||||
# Note: global ignored_paths from serena_config.yml are also applied additively.
|
|
||||||
ignored_paths:
|
|
||||||
|
|
||||||
# list of mode names that are to be activated by default, overriding the setting in the global configuration.
|
|
||||||
# The full set of modes to be activated is base_modes (from global config) + default_modes + added_modes.
|
|
||||||
# If the setting is undefined/empty, the default_modes from the global configuration (serena_config.yml) apply.
|
|
||||||
# Otherwise, this overrides the setting from the global configuration (serena_config.yml).
|
|
||||||
# Therefore, you can set this to [] if you do not want the default modes defined in the global config to apply
|
|
||||||
# for this project.
|
|
||||||
# This setting can, in turn, be overridden by CLI parameters (--mode).
|
|
||||||
# See https://oraios.github.io/serena/02-usage/050_configuration.html#modes
|
|
||||||
default_modes:
|
|
||||||
|
|
||||||
# list of mode names to be activated additionally for this project, e.g. ["query-projects"]
|
|
||||||
# The full set of modes to be activated is base_modes (from global config) + default_modes + added_modes.
|
|
||||||
# See https://oraios.github.io/serena/02-usage/050_configuration.html#modes
|
|
||||||
added_modes:
|
|
||||||
|
|
||||||
# list of tool names to exclude.
|
|
||||||
# This extends the existing exclusions (e.g. from the global configuration)
|
|
||||||
# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html
|
|
||||||
excluded_tools: []
|
|
||||||
|
|
||||||
# list of tools to include that would otherwise be disabled (particularly optional tools that are disabled by default).
|
|
||||||
# This extends the existing inclusions (e.g. from the global configuration).
|
|
||||||
# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html
|
|
||||||
included_optional_tools: []
|
|
||||||
|
|
||||||
# fixed set of tools to use as the base tool set (if non-empty), replacing Serena's default set of tools.
|
|
||||||
# This cannot be combined with non-empty excluded_tools or included_optional_tools.
|
|
||||||
# Find the list of tools here: https://oraios.github.io/serena/01-about/035_tools.html
|
|
||||||
fixed_tools: []
|
|
||||||
|
|
||||||
# time budget (seconds) per tool call for the retrieval of additional symbol information
|
|
||||||
# such as docstrings or parameter information.
|
|
||||||
# This overrides the corresponding setting in the global configuration; see the documentation there.
|
|
||||||
# If null or missing, use the setting from the global configuration.
|
|
||||||
symbol_info_budget:
|
|
||||||
|
|
||||||
# The language backend to use for this project.
|
|
||||||
# If not set, the global setting from serena_config.yml is used.
|
|
||||||
# Valid values: LSP, JetBrains
|
|
||||||
# Note: the backend is fixed at startup. If a project with a different backend
|
|
||||||
# is activated post-init, an error will be returned.
|
|
||||||
language_backend:
|
|
||||||
|
|
||||||
# line ending convention to use when writing source files.
|
|
||||||
# Possible values: unset (use global setting), "lf", "crlf", or "native" (platform default)
|
|
||||||
# This does not affect Serena's own files (e.g. memories and configuration files), which always use native line endings.
|
|
||||||
line_ending:
|
|
||||||
|
|
||||||
# list of regex patterns which, when matched, mark a memory entry as read-only.
|
|
||||||
# Extends the list from the global configuration, merging the two lists.
|
|
||||||
read_only_memory_patterns: []
|
|
||||||
|
|
||||||
# list of regex patterns for memories to completely ignore.
|
|
||||||
# Matching memories will not appear in list_memories or activate_project output
|
|
||||||
# and cannot be accessed via read_memory or write_memory.
|
|
||||||
# To access ignored memory files, use the read_file tool on the raw file path.
|
|
||||||
# Extends the list from the global configuration, merging the two lists.
|
|
||||||
# Example: ["_archive/.*", "_episodes/.*"]
|
|
||||||
ignored_memory_patterns: []
|
|
||||||
|
|
||||||
# advanced configuration option for setting language-server-specific options.
|
|
||||||
# Maps the language key to the options.
|
|
||||||
# Have a look at the docstring of the constructors of the LS implementations within solidlsp (e.g., for C# or PHP) to see which options are available.
|
|
||||||
# No documentation on options means no options are available.
|
|
||||||
ls_specific_settings: {}
|
|
||||||
|
|
||||||
# list of additional workspace folder paths for cross-package reference support (e.g. in monorepos).
|
|
||||||
# Paths can be absolute or relative to the project root.
|
|
||||||
# Each folder is registered as an LSP workspace folder, enabling language servers to discover
|
|
||||||
# symbols and references across package boundaries.
|
|
||||||
# Currently supported for: TypeScript.
|
|
||||||
# Example:
|
|
||||||
# additional_workspace_folders:
|
|
||||||
# - ../sibling-package
|
|
||||||
# - ../shared-lib
|
|
||||||
additional_workspace_folders: []
|
|
||||||
|
|
||||||
# whether the project is in read-only mode
|
|
||||||
# If set to true, all editing tools will be disabled and attempts to use them will result in an error
|
|
||||||
# Added on 2025-04-18
|
|
||||||
read_only: false
|
|
||||||
|
|
||||||
# whether to use project's .gitignore files to ignore files
|
|
||||||
ignore_all_files_in_gitignore: true
|
|
||||||
|
|
||||||
# initial prompt for the project. It will always be given to the LLM upon activating the project
|
|
||||||
# (contrary to the memories, which are loaded on demand).
|
|
||||||
initial_prompt: ''
|
|
||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||||
#include "mlir/IR/IRMapping.h"
|
#include "mlir/IR/IRMapping.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||||
@@ -138,26 +139,49 @@ static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value ba
|
|||||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts every entry of `values` into an index attribute (via
/// Builder::getIndexAttr) and returns them, in order, as fold results.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(values.size());
  for (size_t position = 0, count = values.size(); position < count; ++position)
    indexAttrs.push_back(builder.getIndexAttr(values[position]));
  return indexAttrs;
}
|
||||||
|
|
||||||
|
/// Returns `rank` fold results, each one the index attribute 1.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  SmallVector<OpFoldResult, 4> unitStrides;
  unitStrides.reserve(rank);
  // Count down instead of up; one entry is appended per dimension.
  for (int64_t remaining = rank; remaining > 0; --remaining)
    unitStrides.push_back(builder.getIndexAttr(1));
  return unitStrides;
}
|
||||||
|
|
||||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||||
tensor::ParallelInsertSliceOp insertSlice,
|
Location loc,
|
||||||
ShapedType destinationType,
|
ShapedType destinationType,
|
||||||
|
ArrayRef<OpFoldResult> mixedOffsets,
|
||||||
|
ArrayRef<int64_t> additionalOffsets,
|
||||||
IRMapping& mapper) {
|
IRMapping& mapper) {
|
||||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||||
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||||
|
|
||||||
Value totalOffset;
|
Value totalOffset;
|
||||||
Location loc = insertSlice.getLoc();
|
for (auto [dim, offset] : llvm::enumerate(mixedOffsets)) {
|
||||||
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
|
||||||
int64_t scale = strides[dim] * elementBytes;
|
int64_t scale = strides[dim] * elementBytes;
|
||||||
Value scaledOffset;
|
Value scaledOffset;
|
||||||
if (auto attr = dyn_cast<Attribute>(offset)) {
|
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||||
assert(intAttr && "expected integer offset attribute");
|
assert(intAttr && "expected integer offset attribute");
|
||||||
scaledOffset =
|
scaledOffset = getOrCreateIndexConstant(rewriter,
|
||||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
|
rewriter.getInsertionBlock()->getParentOp(),
|
||||||
}
|
(intAttr.getInt() + additionalOffsets[dim]) * scale);
|
||||||
else {
|
} else {
|
||||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||||
|
if (additionalOffsets[dim] != 0) {
|
||||||
|
Value staticOffset = getOrCreateIndexConstant(rewriter,
|
||||||
|
rewriter.getInsertionBlock()->getParentOp(),
|
||||||
|
additionalOffsets[dim] * scale);
|
||||||
|
scaledOffset = arith::AddIOp::create(rewriter, loc, scaledOffset, staticOffset).getResult();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
totalOffset =
|
totalOffset =
|
||||||
@@ -169,6 +193,127 @@ static Value createHostTargetOffset(IRRewriter& rewriter,
|
|||||||
return totalOffset;
|
return totalOffset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convenience overload: forwards to the general createHostTargetOffset using
/// the location and mixed offsets of `insertSlice` itself, together with an
/// all-zero vector of additional offsets (one entry per dimension of
/// `destinationType`).
static Value createHostTargetOffset(IRRewriter& rewriter,
                                    tensor::ParallelInsertSliceOp insertSlice,
                                    ShapedType destinationType,
                                    IRMapping& mapper) {
  Location sliceLoc = insertSlice.getLoc();
  SmallVector<OpFoldResult> sliceOffsets = insertSlice.getMixedOffsets();
  SmallVector<int64_t> noExtraOffsets(destinationType.getRank(), 0);
  return createHostTargetOffset(rewriter, sliceLoc, destinationType, sliceOffsets, noExtraOffsets, mapper);
}
|
||||||
|
|
||||||
|
/// Combines each base offset with the static fragment offset of the same
/// dimension.
///
/// For every dimension `dim` of `baseOffsets`:
///  - an attribute base (must be an IntegerAttr) is folded together with
///    `fragmentOffsets[dim]` into a single index attribute;
///  - a value base is first remapped through `mapper`; it is used unchanged
///    when the fragment offset is zero, otherwise an arith::AddIOp adding an
///    index constant for the fragment offset is created at `loc`.
///
/// `fragmentOffsets` must hold at least one entry per base offset; this is
/// asserted because the function has no way to report failure.
static SmallVector<OpFoldResult, 4> buildFragmentOffsets(IRRewriter& rewriter,
                                                         Location loc,
                                                         ArrayRef<OpFoldResult> baseOffsets,
                                                         ArrayRef<int64_t> fragmentOffsets,
                                                         IRMapping& mapper) {
  assert(fragmentOffsets.size() >= baseOffsets.size() && "expected one fragment offset per base offset");

  SmallVector<OpFoldResult, 4> combined;
  // Exactly one result is produced per base offset.
  combined.reserve(baseOffsets.size());
  for (auto [dim, baseOffset] : llvm::enumerate(baseOffsets)) {
    if (auto attr = dyn_cast<Attribute>(baseOffset)) {
      int64_t base = cast<IntegerAttr>(attr).getInt();
      combined.push_back(rewriter.getIndexAttr(base + fragmentOffsets[dim]));
      continue;
    }

    Value dynamicBase = mapper.lookupOrDefault(cast<Value>(baseOffset));
    if (fragmentOffsets[dim] == 0) {
      // Nothing to add: reuse the (remapped) dynamic offset as-is.
      combined.push_back(dynamicBase);
      continue;
    }

    Value staticOffset =
        getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), fragmentOffsets[dim]);
    combined.push_back(arith::AddIOp::create(rewriter, loc, dynamicBase, staticOffset).getResult());
  }
  return combined;
}
|
||||||
|
|
||||||
|
/// Lowers a "fragment_assembly" reconciliator that feeds a host output by
/// inserting each of its fragments into `hostTarget` with
/// tensor::InsertSliceOp.
///
/// Fragment `i` is taken from operand `operandIndices[i]`, where operand 0 is
/// the reconciliator input and the following operands are its extra fragment
/// operands (all remapped through `mapper`). When the operand's shape differs
/// from the fragment's shape, the fragment is first carved out with
/// tensor::ExtractSliceOp at dimension-0 offset
/// `ordinal * fragmentShape[0]`, where `ordinal` counts the fragments already
/// taken from that same operand. The insertion offsets are `baseOffsets`
/// combined with the fragment's static offsets.
///
/// \returns the tensor value produced by the last insertion (or `hostTarget`
/// itself when there are no fragments), or failure after emitting an op error
/// when the types are not static ranked tensors, the fragment metadata is
/// missing, too short or has non-unit strides, an operand index is out of
/// range, or the number of base offsets does not match the result rank.
static FailureOr<Value> lowerFragmentAssemblyHostCopies(IRRewriter& rewriter,
                                                        spatial::SpatReconciliatorOp reconciliator,
                                                        Value hostTarget,
                                                        ArrayRef<OpFoldResult> baseOffsets,
                                                        IRMapping& mapper) {
  auto hostTargetType = dyn_cast<RankedTensorType>(hostTarget.getType());
  auto resultType = dyn_cast<RankedTensorType>(reconciliator.getOutput().getType());
  if (!hostTargetType || !resultType || !resultType.hasStaticShape())
    return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor results");

  std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
  std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
  if (!operandIndicesAttr || !fragmentStridesAttr)
    return reconciliator.emitOpError(
        "fragment assembly lowering requires explicit operand indices and unit strides");

  ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
  ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
  ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
  ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
  int64_t rank = resultType.getRank();
  Location loc = reconciliator.getLoc();

  // The insertion needs exactly one offset per fragment dimension; anything
  // else would index past `fragmentOffsets` or build a malformed insert_slice.
  if (static_cast<int64_t>(baseOffsets.size()) != rank)
    return reconciliator.emitOpError("fragment assembly lowering requires one base offset per result dimension");

  // The flat metadata arrays are indexed as [fragmentIndex * rank + dim];
  // reject metadata that is too short instead of reading out of bounds.
  size_t requiredEntries = operandIndices.size() * static_cast<size_t>(rank);
  if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
      || flatStrides.size() < requiredEntries)
    return reconciliator.emitOpError(
        "fragment assembly metadata must provide one offset, size and stride per fragment dimension");

  SmallVector<Value> fragmentOperands {reconciliator.getInput()};
  llvm::append_range(fragmentOperands, reconciliator.getFragments());

  // Number of fragments already consumed from each operand.
  DenseMap<int64_t, int64_t> packedFragmentOrdinals;
  for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
    int64_t operandIndex = operandIndices[fragmentIndex];
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");

    // Per-fragment views into the flat metadata; no copies are needed.
    int64_t firstFlatIndex = fragmentIndex * rank;
    ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstFlatIndex, rank);
    ArrayRef<int64_t> fragmentShape = flatSizes.slice(firstFlatIndex, rank);
    for (int64_t stride : flatStrides.slice(firstFlatIndex, rank))
      if (stride != 1)
        return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");

    Value source = mapper.lookupOrDefault(fragmentOperands[operandIndex]);
    auto sourceType = dyn_cast<ShapedType>(source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");

    int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;

    Value fragment = source;
    if (sourceType.getShape() != fragmentShape) {
      // The operand holds more than this fragment: slice the fragment out,
      // stepping along dimension 0 by one fragment extent per ordinal.
      if (rank == 0)
        return reconciliator.emitOpError("fragment assembly lowering cannot slice a rank-0 fragment");
      SmallVector<int64_t, 4> extractOffsets(rank, 0);
      extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
      fragment = tensor::ExtractSliceOp::create(rewriter,
                                                loc,
                                                source,
                                                getStaticIndexAttrs(rewriter, extractOffsets),
                                                getStaticIndexAttrs(rewriter, fragmentShape),
                                                getUnitStrides(rewriter, rank));
    }

    hostTarget = tensor::InsertSliceOp::create(rewriter,
                                               loc,
                                               fragment,
                                               hostTarget,
                                               buildFragmentOffsets(rewriter, loc, baseOffsets, fragmentOffsets, mapper),
                                               getStaticIndexAttrs(rewriter, fragmentShape),
                                               getUnitStrides(rewriter, rank))
                     .getResult();
  }

  return hostTarget;
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatScheduledComputeBatch computeBatchOp,
|
||||||
@@ -207,8 +352,10 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
|
|
||||||
SmallVector<unsigned> returnOperandIndices;
|
SmallVector<unsigned> returnOperandIndices;
|
||||||
if (computeBatchOp.getNumResults() != 0) {
|
if (computeBatchOp.getNumResults() != 0) {
|
||||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
returnOperandIndices.resize(computeBatchOp.getNumResults(), std::numeric_limits<unsigned>::max());
|
||||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||||
|
if (result.use_empty())
|
||||||
|
continue;
|
||||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||||
if (failed(returnOperandIndex))
|
if (failed(returnOperandIndex))
|
||||||
return computeBatchOp.emitOpError(
|
return computeBatchOp.emitOpError(
|
||||||
@@ -271,6 +418,18 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
if (isa<spatial::SpatYieldOp>(op))
|
if (isa<spatial::SpatYieldOp>(op))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
||||||
|
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||||
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
|
for (Operation* user : reconciliator.getOutput().getUsers()) {
|
||||||
|
if (!isa<tensor::ParallelInsertSliceOp>(user))
|
||||||
|
return reconciliator.emitOpError(
|
||||||
|
"fragment assembly reconciliator lowering expects only tensor.parallel_insert_slice users");
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||||
if (!firstOutputArg)
|
if (!firstOutputArg)
|
||||||
@@ -287,10 +446,28 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatSchedul
|
|||||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||||
if (resultIndex >= returnOperandIndices.size())
|
if (resultIndex >= returnOperandIndices.size())
|
||||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||||
|
if (returnOperandIndices[resultIndex] == std::numeric_limits<unsigned>::max())
|
||||||
|
continue;
|
||||||
|
|
||||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
|
||||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||||
|
if (auto reconciliator =
|
||||||
|
insertSlice.getSource().getDefiningOp<spatial::SpatReconciliatorOp>()) {
|
||||||
|
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||||
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
|
FailureOr<Value> updatedHostTarget = lowerFragmentAssemblyHostCopies(rewriter,
|
||||||
|
reconciliator,
|
||||||
|
hostTarget,
|
||||||
|
insertSlice.getMixedOffsets(),
|
||||||
|
mapper);
|
||||||
|
if (failed(updatedHostTarget))
|
||||||
|
return failure();
|
||||||
|
hostOutputTensors[resultIndex] = *updatedHostTarget;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
|
||||||
|
|||||||
@@ -30,6 +30,91 @@ static bool isChannelUseChainOp(Operation* op) {
|
|||||||
pim::PimTransposeOp>(op);
|
pim::PimTransposeOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns an index constant holding the byte offset of the element at
/// `fragmentOffsets` inside a row-major tensor of type `destinationType`:
/// the sum over all dimensions of offset * row-major stride * element size.
/// The constant is obtained through getOrCreateIndexConstant on the parent op
/// of the rewriter's current insertion block. `loc` is not used.
static Value createStaticHostTargetOffset(IRRewriter& rewriter,
                                          Location loc,
                                          ShapedType destinationType,
                                          ArrayRef<int64_t> fragmentOffsets) {
  SmallVector<int64_t> rowMajorStrides = computeRowMajorStrides(destinationType.getShape());
  int64_t bytesPerElement = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));

  int64_t totalBytes = 0;
  for (size_t dim = 0, dimCount = fragmentOffsets.size(); dim < dimCount; ++dim) {
    int64_t elementOffset = fragmentOffsets[dim] * rowMajorStrides[dim];
    totalBytes += elementOffset * bytesPerElement;
  }

  Operation* constantScope = rewriter.getInsertionBlock()->getParentOp();
  return getOrCreateIndexConstant(rewriter, constantScope, totalBytes);
}
|
||||||
|
|
||||||
|
/// Lowers a "fragment_assembly" reconciliator into a chain of
/// pim::PimMemCopyDevToHostOp copies that fill a fresh empty tensor of the
/// reconciliator's result type.
///
/// Fragment `i` is read from operand `operandIndices[i]`, where operand 0 is
/// the reconciliator input and the following operands are its extra fragment
/// operands (all remapped through `mapping`). The destination byte offset is
/// derived from the fragment's static offsets in the row-major result; the
/// source byte offset is the number of fragments already taken from the same
/// operand multiplied by this fragment's byte size.
///
/// \returns the output of the last copy (or the empty tensor when there are
/// no fragments), or failure after emitting an op error when the result or an
/// operand is not a static shaped type, the fragment metadata is missing, too
/// short or has non-unit strides, an operand index is out of range, or the
/// copy size does not pass pim::getCheckedI32Attr.
static FailureOr<Value> lowerFragmentAssemblyReconciliator(IRRewriter& rewriter,
                                                           spatial::SpatReconciliatorOp reconciliator,
                                                           IRMapping& mapping) {
  auto resultType = dyn_cast<ShapedType>(reconciliator.getOutput().getType());
  if (!resultType || !resultType.hasStaticShape())
    return reconciliator.emitOpError("fragment assembly lowering requires a static ranked tensor result");

  std::optional<StringRef> modeAttr = reconciliator.getMode();
  std::optional<ArrayRef<int64_t>> operandIndicesAttr = reconciliator.getFragmentOperandIndices();
  std::optional<ArrayRef<int64_t>> fragmentStridesAttr = reconciliator.getFragmentStrides();
  if (!modeAttr || *modeAttr != "fragment_assembly" || !operandIndicesAttr || !fragmentStridesAttr)
    return reconciliator.emitOpError("fragment assembly lowering requires explicit fragment metadata");

  ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
  ArrayRef<int64_t> flatOffsets = reconciliator.getFragmentOffsets();
  ArrayRef<int64_t> flatSizes = reconciliator.getFragmentSizes();
  ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
  int64_t rank = resultType.getRank();

  // The flat metadata arrays are indexed as [fragmentIndex * rank + dim];
  // reject metadata that is too short before creating any IR.
  size_t requiredEntries = operandIndices.size() * static_cast<size_t>(rank);
  if (flatOffsets.size() < requiredEntries || flatSizes.size() < requiredEntries
      || flatStrides.size() < requiredEntries)
    return reconciliator.emitOpError(
        "fragment assembly metadata must provide one offset, size and stride per fragment dimension");

  SmallVector<Value> fragmentOperands {reconciliator.getInput()};
  llvm::append_range(fragmentOperands, reconciliator.getFragments());

  Value currentOutput = createEmptyTensorFromShaped(rewriter, reconciliator.getLoc(), resultType);
  // Number of fragments already consumed from each operand.
  DenseMap<int64_t, int64_t> packedFragmentOrdinals;
  for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
    int64_t operandIndex = operandIndices[fragmentIndex];
    if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
      return reconciliator.emitOpError("fragment assembly operand index is out of range");

    int64_t firstFlatIndex = fragmentIndex * rank;
    ArrayRef<int64_t> fragmentOffsets = flatOffsets.slice(firstFlatIndex, rank);
    int64_t fragmentElements = 1;
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (flatStrides[firstFlatIndex + dim] != 1)
        return reconciliator.emitOpError("fragment assembly lowering only supports unit strides");
      fragmentElements *= flatSizes[firstFlatIndex + dim];
    }

    Value source = mapping.lookupOrDefault(fragmentOperands[operandIndex]);
    auto sourceType = dyn_cast<ShapedType>(source.getType());
    if (!sourceType || !sourceType.hasStaticShape())
      return reconciliator.emitOpError("fragment assembly lowering requires static ranked tensor operands");

    int64_t fragmentBytes =
        fragmentElements * static_cast<int64_t>(getElementTypeSizeInBytes(sourceType.getElementType()));
    auto sizeAttr = pim::getCheckedI32Attr(rewriter,
                                           reconciliator.getOperation(),
                                           fragmentBytes,
                                           "fragment assembly host copy size");
    if (failed(sizeAttr))
      return failure();

    int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
    // NOTE(review): one linear copy of `fragmentBytes` presumes the fragment
    // occupies a contiguous byte range of the row-major result; confirm that
    // producers only emit such fragments.
    Value hostTargetOffset =
        createStaticHostTargetOffset(rewriter, reconciliator.getLoc(), resultType, fragmentOffsets);
    Value deviceSourceOffset = getOrCreateIndexConstant(rewriter,
                                                        rewriter.getInsertionBlock()->getParentOp(),
                                                        packedFragmentOrdinal * fragmentBytes);
    currentOutput = pim::PimMemCopyDevToHostOp::create(rewriter,
                                                       reconciliator.getLoc(),
                                                       currentOutput.getType(),
                                                       hostTargetOffset,
                                                       deviceSourceOffset,
                                                       currentOutput,
                                                       source,
                                                       *sizeAttr)
                        .getOutput();
  }

  return currentOutput;
}
|
||||||
|
|
||||||
static void
|
static void
|
||||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||||
for (Value operand : op->getOperands()) {
|
for (Value operand : op->getOperands()) {
|
||||||
@@ -131,6 +216,17 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatSchedule
|
|||||||
mapping.map(*weightArg, weight);
|
mapping.map(*weightArg, weight);
|
||||||
}
|
}
|
||||||
for (Operation& op : block.without_terminator()) {
|
for (Operation& op : block.without_terminator()) {
|
||||||
|
if (auto reconciliator = dyn_cast<spatial::SpatReconciliatorOp>(op)) {
|
||||||
|
std::optional<StringRef> modeAttr = reconciliator.getMode();
|
||||||
|
if (modeAttr && *modeAttr == "fragment_assembly") {
|
||||||
|
auto lowered = lowerFragmentAssemblyReconciliator(rewriter, reconciliator, mapping);
|
||||||
|
if (failed(lowered))
|
||||||
|
return false;
|
||||||
|
mapping.map(reconciliator.getOutput(), *lowered);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||||
|
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
@@ -11,6 +15,107 @@ namespace raptor {
|
|||||||
|
|
||||||
} // namespace raptor
|
} // namespace raptor
|
||||||
|
|
||||||
|
/// Converts every entry of `values` into an index attribute (via
/// Builder::getIndexAttr) and returns them, in order, as fold results.
static SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
  SmallVector<OpFoldResult, 4> indexAttrs;
  indexAttrs.reserve(values.size());
  for (size_t position = 0, count = values.size(); position < count; ++position)
    indexAttrs.push_back(builder.getIndexAttr(values[position]));
  return indexAttrs;
}
|
||||||
|
|
||||||
|
/// Returns `rank` fold results, each one the index attribute 1.
static SmallVector<OpFoldResult, 4> getUnitStrides(Builder& builder, int64_t rank) {
  SmallVector<OpFoldResult, 4> unitStrides;
  unitStrides.reserve(rank);
  // Count down instead of up; one entry is appended per dimension.
  for (int64_t remaining = rank; remaining > 0; --remaining)
    unitStrides.push_back(builder.getIndexAttr(1));
  return unitStrides;
}
|
||||||
|
|
||||||
|
struct LowerFragmentAssemblyReconciliatorPattern
|
||||||
|
: OpConversionPattern<spatial::SpatReconciliatorOp> {
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(spatial::SpatReconciliatorOp op,
|
||||||
|
OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter& rewriter) const override {
|
||||||
|
std::optional<StringRef> modeAttr = op.getMode();
|
||||||
|
if (!modeAttr || *modeAttr != "fragment_assembly")
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<ShapedType>(op.getOutput().getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return op.emitOpError("fragment assembly lowering requires a static ranked tensor result");
|
||||||
|
|
||||||
|
std::optional<ArrayRef<int64_t>> operandIndicesAttr = op.getFragmentOperandIndices();
|
||||||
|
std::optional<ArrayRef<int64_t>> fragmentStridesAttr = op.getFragmentStrides();
|
||||||
|
if (!operandIndicesAttr || !fragmentStridesAttr)
|
||||||
|
return op.emitOpError("fragment assembly lowering requires explicit fragment metadata");
|
||||||
|
|
||||||
|
ArrayRef<int64_t> operandIndices = *operandIndicesAttr;
|
||||||
|
ArrayRef<int64_t> flatOffsets = op.getFragmentOffsets();
|
||||||
|
ArrayRef<int64_t> flatSizes = op.getFragmentSizes();
|
||||||
|
ArrayRef<int64_t> flatStrides = *fragmentStridesAttr;
|
||||||
|
int64_t rank = resultType.getRank();
|
||||||
|
|
||||||
|
SmallVector<Value> fragmentOperands {adaptor.getInput()};
|
||||||
|
llvm::append_range(fragmentOperands, adaptor.getFragments());
|
||||||
|
|
||||||
|
Value currentOutput =
|
||||||
|
tensor::EmptyOp::create(rewriter, op.getLoc(), resultType.getShape(), resultType.getElementType()).getResult();
|
||||||
|
DenseMap<int64_t, int64_t> packedFragmentOrdinals;
|
||||||
|
for (int64_t fragmentIndex = 0; fragmentIndex < static_cast<int64_t>(operandIndices.size()); ++fragmentIndex) {
|
||||||
|
int64_t operandIndex = operandIndices[fragmentIndex];
|
||||||
|
if (operandIndex < 0 || operandIndex >= static_cast<int64_t>(fragmentOperands.size()))
|
||||||
|
return op.emitOpError("fragment assembly operand index is out of range");
|
||||||
|
|
||||||
|
SmallVector<int64_t, 4> fragmentOffsets;
|
||||||
|
int64_t fragmentElements = 1;
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||||
|
int64_t flatIndex = fragmentIndex * rank + dim;
|
||||||
|
if (flatStrides[flatIndex] != 1)
|
||||||
|
return op.emitOpError("fragment assembly lowering only supports unit strides");
|
||||||
|
fragmentOffsets.push_back(flatOffsets[flatIndex]);
|
||||||
|
fragmentElements *= flatSizes[flatIndex];
|
||||||
|
}
|
||||||
|
|
||||||
|
Value source = fragmentOperands[operandIndex];
|
||||||
|
auto sourceType = dyn_cast<RankedTensorType>(source.getType());
|
||||||
|
if (!sourceType || !sourceType.hasStaticShape())
|
||||||
|
return op.emitOpError("fragment assembly lowering requires static ranked tensor operands");
|
||||||
|
|
||||||
|
int64_t packedFragmentOrdinal = packedFragmentOrdinals[operandIndex]++;
|
||||||
|
SmallVector<int64_t, 4> fragmentShape;
|
||||||
|
fragmentShape.reserve(rank);
|
||||||
|
for (int64_t dim = 0; dim < rank; ++dim)
|
||||||
|
fragmentShape.push_back(flatSizes[fragmentIndex * rank + dim]);
|
||||||
|
|
||||||
|
Value fragment = source;
|
||||||
|
if (llvm::to_vector(sourceType.getShape()) != fragmentShape) {
|
||||||
|
SmallVector<int64_t, 4> extractOffsets(rank, 0);
|
||||||
|
extractOffsets[0] = packedFragmentOrdinal * fragmentShape[0];
|
||||||
|
fragment = tensor::ExtractSliceOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
source,
|
||||||
|
getStaticIndexAttrs(rewriter, extractOffsets),
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
|
getUnitStrides(rewriter, rank));
|
||||||
|
}
|
||||||
|
|
||||||
|
currentOutput = tensor::InsertSliceOp::create(rewriter,
|
||||||
|
op.getLoc(),
|
||||||
|
fragment,
|
||||||
|
currentOutput,
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentOffsets),
|
||||||
|
getStaticIndexAttrs(rewriter, fragmentShape),
|
||||||
|
getUnitStrides(rewriter, rank))
|
||||||
|
.getResult();
|
||||||
|
}
|
||||||
|
|
||||||
|
rewriter.replaceOp(op, currentOutput);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void populateInitialPatterns(RewritePatternSet& patterns) {
|
void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||||
raptor::populateWithGenerated(patterns);
|
raptor::populateWithGenerated(patterns);
|
||||||
populateTransposeLoweringPatterns(patterns);
|
populateTransposeLoweringPatterns(patterns);
|
||||||
@@ -19,6 +124,7 @@ void populateInitialPatterns(RewritePatternSet& patterns) {
|
|||||||
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
||||||
raptor::populateWithGenerated(patterns);
|
raptor::populateWithGenerated(patterns);
|
||||||
populateTransposeLoweringPatterns(patterns);
|
populateTransposeLoweringPatterns(patterns);
|
||||||
|
patterns.add<LowerFragmentAssemblyReconciliatorPattern>(patterns.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace onnx_mlir
|
} // namespace onnx_mlir
|
||||||
|
|||||||
+250
-27
@@ -306,10 +306,12 @@ struct ProjectedExtractReplacement {
|
|||||||
struct PendingProjectedHostOutputFragment {
|
struct PendingProjectedHostOutputFragment {
|
||||||
Value originalOutput;
|
Value originalOutput;
|
||||||
ClassId sourceClass = 0;
|
ClassId sourceClass = 0;
|
||||||
|
ProducerKey producerKey;
|
||||||
Value operand;
|
Value operand;
|
||||||
RankedTensorType operandType;
|
RankedTensorType operandType;
|
||||||
RankedTensorType fragmentType;
|
RankedTensorType fragmentType;
|
||||||
int64_t packedFragmentIndex = -1;
|
int64_t packedFragmentIndex = -1;
|
||||||
|
int64_t currentLane = -1;
|
||||||
SmallVector<int64_t, 4> offsets;
|
SmallVector<int64_t, 4> offsets;
|
||||||
SmallVector<int64_t, 4> sizes;
|
SmallVector<int64_t, 4> sizes;
|
||||||
SmallVector<int64_t, 4> strides;
|
SmallVector<int64_t, 4> strides;
|
||||||
@@ -1137,6 +1139,59 @@ LogicalResult createEmptyMaterializedOps(MaterializerState& state) {
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void setInsertionPointForNewMaterializedOp(MaterializerState& state) {
|
||||||
|
Block& funcBlock = state.func.getBody().front();
|
||||||
|
for (Operation& op : funcBlock) {
|
||||||
|
if (state.oldComputeOps.contains(&op)) {
|
||||||
|
state.rewriter.setInsertionPoint(&op);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.rewriter.setInsertionPointToEnd(&funcBlock);
|
||||||
|
}
|
||||||
|
|
||||||
|
FailureOr<ClassId> createProjectedHostAssemblyClass(MaterializerState& state, Value originalOutput, Location loc) {
|
||||||
|
DenseSet<CpuId> usedCpus;
|
||||||
|
for (const auto& [cpu, _] : state.cpuToClass)
|
||||||
|
usedCpus.insert(cpu);
|
||||||
|
|
||||||
|
CpuId assemblyCpu = 0;
|
||||||
|
while (usedCpus.contains(assemblyCpu))
|
||||||
|
++assemblyCpu;
|
||||||
|
|
||||||
|
setInsertionPointForNewMaterializedOp(state);
|
||||||
|
|
||||||
|
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||||
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
|
return state.func.emitError("projected host assembly class requires a static ranked tensor output");
|
||||||
|
|
||||||
|
auto compute = SpatScheduledCompute::create(state.rewriter, loc, TypeRange {resultType}, ValueRange {}, ValueRange {});
|
||||||
|
compute.getProperties().setOperandSegmentSizes({0, 0});
|
||||||
|
auto coreIdAttr = pim::getCheckedI32Attr(state.rewriter, state.func, assemblyCpu, "projected host assembly core id");
|
||||||
|
if (failed(coreIdAttr))
|
||||||
|
return failure();
|
||||||
|
compute->setAttr(onnx_mlir::kCoreIdAttrName, *coreIdAttr);
|
||||||
|
|
||||||
|
Block* body = state.rewriter.createBlock(&compute.getBody());
|
||||||
|
state.rewriter.setInsertionPointToEnd(body);
|
||||||
|
Value placeholder =
|
||||||
|
tensor::EmptyOp::create(state.rewriter, loc, resultType.getShape(), resultType.getElementType()).getResult();
|
||||||
|
SpatYieldOp::create(state.rewriter, loc, ValueRange {placeholder});
|
||||||
|
state.rewriter.setInsertionPointAfter(compute.getOperation());
|
||||||
|
|
||||||
|
MaterializedClass materializedClass;
|
||||||
|
materializedClass.id = state.classes.size();
|
||||||
|
materializedClass.cpus.push_back(assemblyCpu);
|
||||||
|
materializedClass.op = compute.getOperation();
|
||||||
|
materializedClass.body = body;
|
||||||
|
materializedClass.hostOutputToResultIndex[originalOutput] = 0;
|
||||||
|
materializedClass.hostOutputs.push_back(originalOutput);
|
||||||
|
state.cpuToClass[assemblyCpu] = materializedClass.id;
|
||||||
|
state.hostOutputOwners[originalOutput] = materializedClass.id;
|
||||||
|
state.classes.push_back(std::move(materializedClass));
|
||||||
|
return state.classes.back().id;
|
||||||
|
}
|
||||||
|
|
||||||
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
|
BlockArgument appendWeight(MaterializerState& state, MaterializedClass& materializedClass, Value weight) {
|
||||||
auto it = materializedClass.weightArgs.find(weight);
|
auto it = materializedClass.weightArgs.find(weight);
|
||||||
if (it != materializedClass.weightArgs.end())
|
if (it != materializedClass.weightArgs.end())
|
||||||
@@ -1897,6 +1952,14 @@ FailureOr<SmallVector<OpFoldResult, 4>> buildProjectedFragmentOffsetsInClass(Mat
|
|||||||
return fragmentOffsets;
|
return fragmentOffsets;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult, 4> getStaticIndexAttrs(Builder& builder, ArrayRef<int64_t> values) {
|
||||||
|
SmallVector<OpFoldResult, 4> attrs;
|
||||||
|
attrs.reserve(values.size());
|
||||||
|
for (int64_t value : values)
|
||||||
|
attrs.push_back(builder.getIndexAttr(value));
|
||||||
|
return attrs;
|
||||||
|
}
|
||||||
|
|
||||||
Value createDim0InsertSlice(
|
Value createDim0InsertSlice(
|
||||||
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
|
MaterializerState& state, Location loc, Value fragment, Value destination, OpFoldResult firstOffset) {
|
||||||
auto fragmentType = cast<RankedTensorType>(fragment.getType());
|
auto fragmentType = cast<RankedTensorType>(fragment.getType());
|
||||||
@@ -3639,6 +3702,9 @@ LogicalResult appendSend(MaterializerState& state,
|
|||||||
|
|
||||||
if (sourceClass.isBatch) {
|
if (sourceClass.isBatch) {
|
||||||
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(sourceClass.body->getTerminator());
|
||||||
|
if (messages.size() != sourceClass.cpus.size())
|
||||||
|
return sourceClass.op->emitError("batch send expects exactly one message per materialized lane")
|
||||||
|
<< " messageCount=" << messages.size() << " laneCount=" << sourceClass.cpus.size();
|
||||||
|
|
||||||
Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc);
|
Value channelId = createLaneIndexedIndexValue(state, sourceClass, messages.channelIds, loc);
|
||||||
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc);
|
Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, messages.sourceCoreIds, loc);
|
||||||
@@ -3686,6 +3752,11 @@ Value appendReceive(
|
|||||||
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
state.rewriter.setInsertionPoint(targetClass.body->getTerminator());
|
||||||
|
|
||||||
if (targetClass.isBatch) {
|
if (targetClass.isBatch) {
|
||||||
|
if (messages.size() != targetClass.cpus.size()) {
|
||||||
|
targetClass.op->emitOpError("batch receive expects exactly one message per materialized lane")
|
||||||
|
<< " messageCount=" << messages.size() << " laneCount=" << targetClass.cpus.size();
|
||||||
|
return Value();
|
||||||
|
}
|
||||||
Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc);
|
Value channelId = createLaneIndexedIndexValue(state, targetClass, messages.channelIds, loc);
|
||||||
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc);
|
Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, messages.sourceCoreIds, loc);
|
||||||
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc);
|
Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, messages.targetCoreIds, loc);
|
||||||
@@ -5481,10 +5552,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedRun(MaterializerStat
|
|||||||
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
||||||
originalOutput,
|
originalOutput,
|
||||||
sourceClass.id,
|
sourceClass.id,
|
||||||
|
ProducerKey {peer, resultIndex},
|
||||||
packed,
|
packed,
|
||||||
cast<RankedTensorType>(packed.getType()),
|
cast<RankedTensorType>(packed.getType()),
|
||||||
fragmentType,
|
fragmentType,
|
||||||
static_cast<int64_t>(runIndex),
|
static_cast<int64_t>(runIndex),
|
||||||
|
static_cast<int64_t>(runIndex),
|
||||||
SmallVector<int64_t, 4>(*offsets),
|
SmallVector<int64_t, 4>(*offsets),
|
||||||
SmallVector<int64_t, 4>(*sizes),
|
SmallVector<int64_t, 4>(*sizes),
|
||||||
SmallVector<int64_t, 4>(*strides),
|
SmallVector<int64_t, 4>(*strides),
|
||||||
@@ -5572,10 +5645,12 @@ FailureOr<bool> recordProjectedScalarHostFragmentsFromPackedValue(MaterializerSt
|
|||||||
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
state.pendingProjectedHostOutputFragments.push_back(PendingProjectedHostOutputFragment {
|
||||||
originalOutput,
|
originalOutput,
|
||||||
sourceClass.id,
|
sourceClass.id,
|
||||||
|
key,
|
||||||
packed,
|
packed,
|
||||||
packedType,
|
packedType,
|
||||||
fragmentType,
|
fragmentType,
|
||||||
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
|
operandIsDim0Packed ? static_cast<int64_t>(fragmentIndex) : -1,
|
||||||
|
static_cast<int64_t>(fragmentIndex),
|
||||||
SmallVector<int64_t, 4>(*offsets),
|
SmallVector<int64_t, 4>(*offsets),
|
||||||
SmallVector<int64_t, 4>(*sizes),
|
SmallVector<int64_t, 4>(*sizes),
|
||||||
SmallVector<int64_t, 4>(*strides),
|
SmallVector<int64_t, 4>(*strides),
|
||||||
@@ -5611,16 +5686,6 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
|
MaterializedClass* ownerClass = &state.classes[ownerIt->second];
|
||||||
if (ownerClass->isBatch) {
|
|
||||||
auto scalarOwnerIt = llvm::find_if(state.classes, [](const MaterializedClass& candidate) {
|
|
||||||
return !candidate.isBatch;
|
|
||||||
});
|
|
||||||
if (scalarOwnerIt == state.classes.end())
|
|
||||||
return ownerClass->op->emitError(
|
|
||||||
"projected host output finalization requires a scalar assembly class when the preferred host owner is batch");
|
|
||||||
ownerClass = &*scalarOwnerIt;
|
|
||||||
state.hostOutputOwners[originalOutput] = ownerClass->id;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
auto resultType = dyn_cast<RankedTensorType>(originalOutput.getType());
|
||||||
if (!resultType || !resultType.hasStaticShape())
|
if (!resultType || !resultType.hasStaticShape())
|
||||||
@@ -5646,6 +5711,119 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
if (allFromSameSourceClass) {
|
if (allFromSameSourceClass) {
|
||||||
ownerClass = &state.classes[fragments.front()->sourceClass];
|
ownerClass = &state.classes[fragments.front()->sourceClass];
|
||||||
state.hostOutputOwners[originalOutput] = ownerClass->id;
|
state.hostOutputOwners[originalOutput] = ownerClass->id;
|
||||||
|
} else {
|
||||||
|
if (!ownerClass->isBatch && ownerClass->hostOutputToResultIndex.contains(originalOutput))
|
||||||
|
goto owner_selected;
|
||||||
|
FailureOr<ClassId> createdOwner =
|
||||||
|
createProjectedHostAssemblyClass(state, originalOutput, fragments.front()->loc);
|
||||||
|
if (failed(createdOwner))
|
||||||
|
return failure();
|
||||||
|
ownerClass = &state.classes[*createdOwner];
|
||||||
|
}
|
||||||
|
owner_selected:
|
||||||
|
|
||||||
|
if (ownerClass->isBatch && allFromSameSourceClass && ownerClass->id == fragments.front()->sourceClass) {
|
||||||
|
auto sourceBatch = dyn_cast<SpatComputeBatch>(fragments.front()->producerKey.instance.op);
|
||||||
|
auto batch = dyn_cast<SpatScheduledComputeBatch>(ownerClass->op);
|
||||||
|
auto inParallelOp = dyn_cast_or_null<SpatInParallelOp>(ownerClass->body->getTerminator());
|
||||||
|
auto resultIt = ownerClass->hostOutputToResultIndex.find(originalOutput);
|
||||||
|
if (!sourceBatch || !batch || !inParallelOp || resultIt == ownerClass->hostOutputToResultIndex.end())
|
||||||
|
return ownerClass->op->emitError("missing batch host assembly state for projected host output");
|
||||||
|
FailureOr<tensor::ParallelInsertSliceOp> sourceProjection =
|
||||||
|
getBatchResultProjectionInsert(sourceBatch, fragments.front()->producerKey.resultIndex);
|
||||||
|
std::optional<BlockArgument> sourceLaneArg = sourceBatch.getLaneArgument();
|
||||||
|
if (failed(sourceProjection) || !sourceLaneArg)
|
||||||
|
return ownerClass->op->emitError(
|
||||||
|
"direct batch host output assembly requires the source batch projection metadata");
|
||||||
|
|
||||||
|
auto outputArg = batch.getOutputArgument(resultIt->second);
|
||||||
|
auto laneArg = batch.getLaneArgument();
|
||||||
|
if (!outputArg || !laneArg)
|
||||||
|
return ownerClass->op->emitError("missing compute_batch output block argument for projected host output");
|
||||||
|
|
||||||
|
if (fragments.size() != ownerClass->cpus.size())
|
||||||
|
return ownerClass->op->emitError(
|
||||||
|
"direct batch host output assembly expects exactly one fragment per materialized lane");
|
||||||
|
|
||||||
|
SmallVector<PendingProjectedHostOutputFragment*, 8> fragmentsByLane(ownerClass->cpus.size(), nullptr);
|
||||||
|
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||||
|
int64_t currentLane = fragmentRecord->currentLane >= 0 ? fragmentRecord->currentLane : fragmentRecord->sourceLane;
|
||||||
|
if (currentLane < 0 || currentLane >= static_cast<int64_t>(fragmentsByLane.size()))
|
||||||
|
return ownerClass->op->emitError("projected batch host output fragment current lane is out of bounds");
|
||||||
|
if (fragmentsByLane[currentLane])
|
||||||
|
return ownerClass->op->emitError("projected batch host output has duplicate fragments for one lane");
|
||||||
|
fragmentsByLane[currentLane] = fragmentRecord;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (llvm::any_of(fragmentsByLane, [](PendingProjectedHostOutputFragment* fragment) { return fragment == nullptr; }))
|
||||||
|
return ownerClass->op->emitError("projected batch host output is missing a fragment for one or more lanes");
|
||||||
|
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> firstSizes =
|
||||||
|
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> firstStrides =
|
||||||
|
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentsByLane.front()->sourceLane);
|
||||||
|
if (failed(firstSizes) || failed(firstStrides))
|
||||||
|
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
|
||||||
|
SmallVector<int64_t, 4> referenceSizes(*firstSizes);
|
||||||
|
SmallVector<int64_t, 4> referenceStrides(*firstStrides);
|
||||||
|
Value laneOperand;
|
||||||
|
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> fragmentSizes =
|
||||||
|
evaluateStaticProjectionIndices(sourceProjection->getMixedSizes(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> fragmentStrides =
|
||||||
|
evaluateStaticProjectionIndices(sourceProjection->getMixedStrides(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||||
|
if (failed(fragmentSizes) || failed(fragmentStrides))
|
||||||
|
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment shape");
|
||||||
|
if (SmallVector<int64_t, 4>(*fragmentSizes) != referenceSizes
|
||||||
|
|| SmallVector<int64_t, 4>(*fragmentStrides) != referenceStrides)
|
||||||
|
return ownerClass->op->emitError(
|
||||||
|
"direct batch host output assembly expects a uniform fragment shape and strides");
|
||||||
|
|
||||||
|
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||||
|
Value operand;
|
||||||
|
if (std::optional<Value> availableValue =
|
||||||
|
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
|
||||||
|
operand = *availableValue;
|
||||||
|
} else {
|
||||||
|
operand = fragmentRecord->operand;
|
||||||
|
}
|
||||||
|
if (!isValueLegalInMaterializedClassBody(operand, *ownerClass))
|
||||||
|
return ownerClass->op->emitError(
|
||||||
|
"projected batch host output assembly requires source-local fragment operands");
|
||||||
|
if (laneOperand && laneOperand != operand)
|
||||||
|
return ownerClass->op->emitError(
|
||||||
|
"direct batch host output assembly expects one shared lane-local fragment producer");
|
||||||
|
laneOperand = operand;
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<OpFoldResult, 4> mixedOffsets;
|
||||||
|
mixedOffsets.reserve(referenceSizes.size());
|
||||||
|
for (size_t dim = 0; dim < referenceSizes.size(); ++dim) {
|
||||||
|
SmallVector<int64_t, 8> offsetsByLane;
|
||||||
|
offsetsByLane.reserve(fragmentsByLane.size());
|
||||||
|
for (PendingProjectedHostOutputFragment* fragmentRecord : fragmentsByLane) {
|
||||||
|
FailureOr<SmallVector<int64_t, 4>> fragmentOffsets =
|
||||||
|
evaluateStaticProjectionIndices(sourceProjection->getMixedOffsets(), *sourceLaneArg, fragmentRecord->sourceLane);
|
||||||
|
if (failed(fragmentOffsets))
|
||||||
|
return ownerClass->op->emitError("failed to evaluate direct batch host output fragment offsets");
|
||||||
|
offsetsByLane.push_back((*fragmentOffsets)[dim]);
|
||||||
|
}
|
||||||
|
mixedOffsets.push_back(allEqual(offsetsByLane)
|
||||||
|
? OpFoldResult(state.rewriter.getIndexAttr(offsetsByLane.front()))
|
||||||
|
: OpFoldResult(createLaneIndexedIndexValue(
|
||||||
|
state, *ownerClass, ArrayRef<int64_t>(offsetsByLane), fragments.front()->loc)));
|
||||||
|
}
|
||||||
|
|
||||||
|
state.hostReplacements[originalOutput] = ownerClass->op->getResult(resultIt->second);
|
||||||
|
state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||||
|
tensor::ParallelInsertSliceOp::create(state.rewriter,
|
||||||
|
fragments.front()->loc,
|
||||||
|
laneOperand,
|
||||||
|
*outputArg,
|
||||||
|
mixedOffsets,
|
||||||
|
getStaticIndexAttrs(state.rewriter, referenceSizes),
|
||||||
|
getStaticIndexAttrs(state.rewriter, referenceStrides));
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
|
state.rewriter.setInsertionPoint(ownerClass->body->getTerminator());
|
||||||
@@ -5656,28 +5834,73 @@ LogicalResult finalizeProjectedHostOutputFragments(MaterializerState& state) {
|
|||||||
SmallVector<int64_t, 64> flatSizes;
|
SmallVector<int64_t, 64> flatSizes;
|
||||||
SmallVector<int64_t, 64> flatStrides;
|
SmallVector<int64_t, 64> flatStrides;
|
||||||
DenseMap<Value, int64_t> operandIndicesByValue;
|
DenseMap<Value, int64_t> operandIndicesByValue;
|
||||||
|
DenseSet<ClassId> emittedBatchForwarding;
|
||||||
|
|
||||||
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
for (PendingProjectedHostOutputFragment* fragmentRecord : fragments) {
|
||||||
Value operand = fragmentRecord->operand;
|
|
||||||
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
MaterializedClass& sourceClass = state.classes[fragmentRecord->sourceClass];
|
||||||
|
Value operand;
|
||||||
|
|
||||||
|
if (std::optional<Value> availableValue =
|
||||||
|
state.availableValues.lookup(state, fragmentRecord->producerKey, sourceClass.id)) {
|
||||||
|
operand = *availableValue;
|
||||||
|
} else if (fragmentRecord->sourceClass == sourceClass.id) {
|
||||||
|
operand = fragmentRecord->operand;
|
||||||
|
} else {
|
||||||
|
return sourceClass.op->emitError(
|
||||||
|
"projected host output fragment assembly is missing source-visible fragment operands before finalization");
|
||||||
|
}
|
||||||
|
|
||||||
if (fragmentRecord->sourceClass != ownerClass->id) {
|
if (fragmentRecord->sourceClass != ownerClass->id) {
|
||||||
if (sourceClass.isBatch || ownerClass->isBatch)
|
if (sourceClass.isBatch && !ownerClass->isBatch) {
|
||||||
return sourceClass.op->emitError(
|
if (!emittedBatchForwarding.insert(sourceClass.id).second) {
|
||||||
"projected host output fragment assembly requires scalarized cross-class operands before finalization");
|
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
|
||||||
MessageVector messages;
|
if (!localized)
|
||||||
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
|
return ownerClass->op->emitError(
|
||||||
sourceClass.cpus.front(),
|
"projected host output fragment assembly is missing forwarded batch fragments");
|
||||||
"projected host output source core id");
|
operand = *localized;
|
||||||
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
|
} else {
|
||||||
ownerClass->cpus.front(),
|
SmallVector<ProducerKey, 8> forwardedKeys;
|
||||||
"projected host output target core id");
|
forwardedKeys.reserve(sourceClass.cpus.size());
|
||||||
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
|
Value forwardedPayload = fragmentRecord->operand;
|
||||||
return failure();
|
for (PendingProjectedHostOutputFragment* candidate : fragments) {
|
||||||
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
if (candidate->sourceClass != sourceClass.id)
|
||||||
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
|
continue;
|
||||||
return failure();
|
if (candidate->operand != forwardedPayload)
|
||||||
operand = appendReceive(state, *ownerClass, fragmentRecord->operandType, messages, fragmentRecord->loc);
|
return ownerClass->op->emitError(
|
||||||
|
"projected host output batch forwarding expects one shared batch payload per source class");
|
||||||
|
forwardedKeys.push_back(candidate->producerKey);
|
||||||
|
}
|
||||||
|
llvm::sort(forwardedKeys, [](ProducerKey lhs, ProducerKey rhs) {
|
||||||
|
return lhs.instance.laneStart < rhs.instance.laneStart;
|
||||||
|
});
|
||||||
|
if (failed(emitClassToClassCommunication(
|
||||||
|
state, sourceClass, *ownerClass, forwardedKeys, forwardedPayload, fragmentRecord->loc)))
|
||||||
|
return failure();
|
||||||
|
std::optional<Value> localized = state.availableValues.lookup(state, fragmentRecord->producerKey, ownerClass->id);
|
||||||
|
if (!localized)
|
||||||
|
return ownerClass->op->emitError(
|
||||||
|
"projected host output fragment assembly failed to recover forwarded batch fragment");
|
||||||
|
operand = *localized;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
MessageVector messages;
|
||||||
|
auto checkedSourceCpu = getCheckedCoreId(sourceClass.op,
|
||||||
|
sourceClass.cpus.front(),
|
||||||
|
"projected host output source core id");
|
||||||
|
auto checkedTargetCpu = getCheckedCoreId(ownerClass->op,
|
||||||
|
ownerClass->cpus.front(),
|
||||||
|
"projected host output target core id");
|
||||||
|
if (failed(checkedSourceCpu) || failed(checkedTargetCpu))
|
||||||
|
return failure();
|
||||||
|
messages.append(state.nextChannelId++, *checkedSourceCpu, *checkedTargetCpu);
|
||||||
|
if (failed(appendSend(state, sourceClass, operand, messages, fragmentRecord->loc)))
|
||||||
|
return failure();
|
||||||
|
operand = appendReceive(state,
|
||||||
|
*ownerClass,
|
||||||
|
cast<RankedTensorType>(operand.getType()),
|
||||||
|
messages,
|
||||||
|
fragmentRecord->loc);
|
||||||
|
}
|
||||||
} else if (!ownerClass->isBatch) {
|
} else if (!ownerClass->isBatch) {
|
||||||
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
|
FailureOr<Value> localOperand = materializeTensorValueForMaterializedClassUse(
|
||||||
state,
|
state,
|
||||||
|
|||||||
@@ -6,15 +6,15 @@ from onnx import TensorProto
|
|||||||
|
|
||||||
# ONNX dtype -> (ctype, printf, ONNX_TYPE_*)
|
# ONNX dtype -> (ctype, printf, ONNX_TYPE_*)
|
||||||
DTYPES = {
|
DTYPES = {
|
||||||
TensorProto.FLOAT: ("float", "%g", "ONNX_TYPE_FLOAT"),
|
TensorProto.FLOAT: ("float", "%.9g", "ONNX_TYPE_FLOAT"),
|
||||||
TensorProto.DOUBLE: ("double", "%g", "ONNX_TYPE_DOUBLE"),
|
TensorProto.DOUBLE: ("double", "%.17g", "ONNX_TYPE_DOUBLE"),
|
||||||
TensorProto.INT64: ("int64_t", "%lld","ONNX_TYPE_INT64"),
|
TensorProto.INT64: ("int64_t", "%lld", "ONNX_TYPE_INT64"),
|
||||||
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
|
TensorProto.INT32: ("int32_t", "%d", "ONNX_TYPE_INT32"),
|
||||||
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
|
TensorProto.UINT8: ("uint8_t", "%u", "ONNX_TYPE_UINT8"),
|
||||||
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
|
TensorProto.INT8: ("int8_t", "%d", "ONNX_TYPE_INT8"),
|
||||||
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"), # stored as byte
|
TensorProto.BOOL: ("uint8_t", "%u", "ONNX_TYPE_BOOL"),
|
||||||
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"), # raw 16-bit
|
TensorProto.FLOAT16: ("uint16_t", "%u", "ONNX_TYPE_FLOAT16"),
|
||||||
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
|
TensorProto.BFLOAT16:("uint16_t", "%u", "ONNX_TYPE_BFLOAT16"),
|
||||||
}
|
}
|
||||||
|
|
||||||
def esc(s): return s.replace("\\","\\\\").replace('"','\\"')
|
def esc(s): return s.replace("\\","\\\\").replace('"','\\"')
|
||||||
|
|||||||
Reference in New Issue
Block a user