Compare commits
156 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 852bef7605 | |||
| 237654dadf | |||
| 6d69600bc1 | |||
| aec80529ca | |||
| 8ddbbcecfa | |||
| 90c4339808 | |||
| 08870de1a6 | |||
| a34ac223c0 | |||
| 0fa10b4074 | |||
| e166ff7e1d | |||
| a70a8f77cf | |||
| 800c0c4316 | |||
| 1e9e61f5a9 | |||
| 27410207c4 | |||
| cbc9808229 | |||
| 69021d56aa | |||
| dc5edd032c | |||
| e33f517221 | |||
| f94b3d1020 | |||
| 20cf40c9ba | |||
| 37a59054a5 | |||
| 2a8faf9c6b | |||
| 01b9d03fc6 | |||
| 501e6c76f3 | |||
| 3c2667f11e | |||
| 0a5e73c3ea | |||
| 636310d0cb | |||
| 356be6ccc2 | |||
| b678e55d3c | |||
| ab63498f3f | |||
| 7c3943bd06 | |||
| c0238c0d06 | |||
| ff36729140 | |||
| cf93caecd5 | |||
| 2d5b03c08f | |||
| a41f694cf0 | |||
| 8bb0babf1b | |||
| 819d8af0f7 | |||
| 832bd7f1f7 | |||
| 82b44a6387 | |||
| 7fcc765d6e | |||
| f34698a2b6 | |||
| 1ab489fe0a | |||
| cbf7b235f1 | |||
| 00414dd1d9 | |||
| 783dffe553 | |||
| 874a2f53e6 | |||
| 4bdaa57656 | |||
| 1a5d7d2a3f | |||
| 013ae0ac2a | |||
| c6b02af7a9 | |||
| d2048bd394 | |||
| 158f0f0c54 | |||
| 532cac8246 | |||
| d609e84054 | |||
| addfc8a86e | |||
| 0f240af271 | |||
| bdc4ca33f3 | |||
| b79c333c6c | |||
| eea9261c7b | |||
| e8a08f6dd0 | |||
| 4855a2e105 | |||
| 3a7a832198 | |||
| 48ca6bd28d | |||
| f595cc6ffd | |||
| c734f1b37e | |||
| b79ce8eeaa | |||
| 76a37e198f | |||
| 7f3c7464b4 | |||
| c77ffa9c56 | |||
| 495186503c | |||
| 2c1da813b5 | |||
| 8337a11ce9 | |||
| d136136d22 | |||
| 074eb183c7 | |||
| 43ed3914b8 | |||
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 | |||
| a50e77ff38 | |||
| f56c4159b5 | |||
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 | |||
| a103ba328b | |||
| e263e05f56 | |||
| 34c29fdec4 | |||
| aa088e2ba5 | |||
| 2836e759ab | |||
| 8071ebab0b | |||
| f1602c0550 | |||
| de0a2f4561 | |||
| 1c4a5bde76 | |||
| 78242e2887 | |||
| fe244d5aa1 | |||
| d09e76c8f9 | |||
| c5e608fa5b | |||
| 43f3ccdd21 | |||
| 8d95c604a6 | |||
| 55eda487dc | |||
| 061139aefb | |||
| ea61540e08 | |||
| 324178cba8 | |||
| e71ba07cd5 | |||
| 64a3805619 | |||
| 9f9e7c0892 | |||
| 03eab42971 | |||
| c15aba5d96 | |||
| 4821e8a55e | |||
| 88bb223bb1 | |||
| 623ee62a04 | |||
| ad56888b0b | |||
| f993840641 | |||
| 0c7db55a24 | |||
| 41de3cb150 | |||
| 4f3570520c | |||
| 628dc630a4 | |||
| 80a7298552 | |||
| 8ad504fcdf | |||
| e6f442c5d2 | |||
| f6b97b3813 | |||
| 26317ea7d0 | |||
| 909c4acfdd | |||
| feaff820e1 | |||
| 1e279ae9bb | |||
| 57f0cca8c0 | |||
| 5ff364027b | |||
| b1272d2283 | |||
| 58e6587697 | |||
| f6c8cc4aa5 | |||
| 566630b99a | |||
| 74931ad75b | |||
| f2fe147961 | |||
| 7bb58e80de | |||
| b2dc9c38b6 | |||
| 3cb6a1abc5 | |||
| 285773fa55 | |||
| bdacb9871d | |||
| 5b9bb0c191 | |||
| f789954ad7 | |||
| b6ba1e4fea | |||
| 717ad160cd | |||
| 905fa9f9a7 | |||
| 62b0a6e19d | |||
| b605585b1f | |||
| 08b0fcd850 | |||
| 9dccc2c701 | |||
| 5c839e62c1 | |||
| 15e8edb9c4 | |||
| 951baca106 | |||
| fc5bccb487 | |||
| 49dea15b95 | |||
| 5545b0f672 | |||
| cff929a083 | |||
| 89b3501aa8 | |||
| 412ca957f6 |
+11
-2
@@ -1,5 +1,14 @@
|
||||
.zed
|
||||
.idea
|
||||
**/.vscode
|
||||
|
||||
.claude
|
||||
AGENTS.md
|
||||
build
|
||||
.codex
|
||||
|
||||
CMakeUserPresets.json
|
||||
|
||||
build_*
|
||||
compile.sh
|
||||
pimcomp_utils/*
|
||||
|
||||
**/__*
|
||||
|
||||
+1
-1
@@ -3,4 +3,4 @@
|
||||
url = https://github.com/onnx/onnx-mlir.git
|
||||
[submodule "backend-simulators/pim/pimsim-nn"]
|
||||
path = backend-simulators/pim/pimsim-nn
|
||||
url = https://github.com/wangxy-2000/pimsim-nn.git
|
||||
url = https://github.com/HEAPLab/pimsim-nn.git
|
||||
|
||||
@@ -0,0 +1,210 @@
|
||||
* Always read the full README.md before doing anything
|
||||
* Build commands:
|
||||
* `cmake --build ./build_release`
|
||||
* `cmake --build ./build_debug`
|
||||
* Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache
|
||||
* Always try the release build first before building with the debug version
|
||||
* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
|
||||
* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
|
||||
|
||||
# Core engineering philosophy
|
||||
|
||||
* Clean architecture matters as much as making the immediate test pass
|
||||
* Prefer fixes that preserve clear ownership boundaries, explicit invariants, and simple dataflow
|
||||
* Do not stack compensating fixes on top of earlier mistakes. If the current approach is becoming messy, stop and explain why
|
||||
* A correct fix should usually make the responsible producer, resolver, verifier, or lowering own the behavior directly
|
||||
* Avoid late repair passes, defensive cleanup, or broad rewrites when a cleaner owner-side fix is possible
|
||||
* Do not hide an upstream modeling bug by normalizing it later in the pipeline. Fix the producer when the producer owns the invariant
|
||||
* Prefer patterns/rewrites for local IR canonicalization. Use module walks only when pass-level structural analysis genuinely requires them
|
||||
* Prefer compact, structured designs over long case-by-case implementations
|
||||
|
||||
# Think before coding
|
||||
|
||||
* State assumptions explicitly before implementing when they affect the design
|
||||
* If multiple interpretations exist, present them instead of silently choosing one
|
||||
* If a simpler approach exists, say so and prefer it unless there is a clear reason not to
|
||||
* If something is unclear, stop, name what is confusing, and ask
|
||||
* If the requested or obvious approach would make the architecture worse, push back and propose a cleaner alternative
|
||||
|
||||
# Code changes
|
||||
|
||||
* Keep changes minimal and localized to the relevant parts of the code
|
||||
* Preserve the existing naming conventions and coding style used in the surrounding code
|
||||
* Keep code easy to read, well organized, and suitable for future extensibility
|
||||
* A function must not exceed roughly 200/250 lines. If a change pushes a function beyond that, extract focused helpers
|
||||
* Prefer clear naming and structure over comments. Add comments only when they materially improve clarity
|
||||
* Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change
|
||||
* Avoid duplicate ad-hoc logic. If the same concept appears in multiple places, consider whether it deserves a shared helper/API
|
||||
* When adding a helper or API, ask:
|
||||
* Could this be useful to another component now
|
||||
* Is another component already implementing the same idea differently
|
||||
* Is this likely to be needed by a future adjacent component
|
||||
* What is the narrowest useful abstraction
|
||||
* What is the correct ownership level for this API
|
||||
* If a shared API is justified, place it at the lowest clean layer that can be used by all relevant consumers without creating dependency cycles or leaking policy across layers
|
||||
* If an existing component should use a newly introduced shared API, refactor that component in the same patch when doing so is directly related and reduces duplication
|
||||
* Do not create broad frameworks just because a helper might someday be useful. Shared APIs should encode a real reusable concept, not speculative generality
|
||||
* If the reusable abstraction is plausible but not clearly needed yet, keep the code local and mention the possible future extraction separately
|
||||
|
||||
# Avoid case-listing designs
|
||||
|
||||
* Avoid solving problems with large chains of `if`/`else`, switches, or repeated special cases that enumerate every possible situation
|
||||
* Long case listings tend to overfit the current tests, grow the codebase, and hide the underlying abstraction
|
||||
* When you see a growing list of special cases, stop and look for the shared concept, data model, interface, or normalization step that would make the cases collapse
|
||||
* Prefer table-driven logic, traits/interfaces, small reusable predicates, structured dispatch, or producer-side normalization when they express the invariant more directly
|
||||
* A few explicit cases are fine when the domain is genuinely small and closed
|
||||
* If the list is likely to grow, refactor toward a cleaner and more compact design instead of adding another branch
|
||||
* When keeping a case list is the pragmatic choice, explain why the domain is closed or why a broader abstraction would be premature
|
||||
|
||||
# Ownership and invariants
|
||||
|
||||
Before implementing, identify the owner of the behavior:
|
||||
|
||||
* A producer should emit IR/data that satisfies the contract of the next stage
|
||||
* A lowering should make representation changes explicit and semantically correct
|
||||
* A resolver should resolve existing structure without silently changing semantics
|
||||
* A verifier should reject invalid states with bounded, actionable diagnostics
|
||||
* Codegen should assume verified invariants and fail clearly if they are violated
|
||||
|
||||
When fixing a bug:
|
||||
|
||||
* State the invariant that was violated
|
||||
* State which component should own that invariant
|
||||
* Fix that component directly
|
||||
* Avoid fixes that merely mask the violation later in the pipeline
|
||||
* Add or preserve verification if the invariant is important enough to regress
|
||||
|
||||
# Refactor and API policy
|
||||
|
||||
You may propose or implement a refactor when:
|
||||
|
||||
* the local fix would duplicate logic
|
||||
* the local fix would violate a layer boundary
|
||||
* the bug exists because responsibility is assigned to the wrong component
|
||||
* multiple components already implement ad-hoc variants of the same concept
|
||||
* a shared helper/API would make the code smaller, clearer, and easier to maintain
|
||||
* existing callers can be migrated cleanly without broad churn
|
||||
* the current implementation is turning into a long list of special cases instead of a structured solution
|
||||
|
||||
When proposing or implementing a refactor:
|
||||
|
||||
* Explain what responsibility is being moved or shared
|
||||
* Justify why the new location is the right ownership level
|
||||
* Keep the API narrow and named after the concept or invariant it represents
|
||||
* Migrate directly related existing users when that improves compactness and consistency
|
||||
* Separate changes required for correctness from optional cleanup
|
||||
* Avoid unrelated renames, formatting changes, or module moves
|
||||
* Do not expand a justified refactor beyond directly related callers
|
||||
|
||||
Do not refactor when:
|
||||
|
||||
* the issue is truly local and a local fix is clearer
|
||||
* the abstraction would have only one user and no clear adjacent use
|
||||
* the abstraction would mix policies from different layers
|
||||
* the refactor would affect unrelated behavior
|
||||
* the refactor is mainly aesthetic
|
||||
|
||||
# Working style
|
||||
|
||||
* Infer style and conventions from the existing code before introducing new patterns
|
||||
* When several implementation options are possible, prefer the simplest one that fits the current architecture and minimizes churn
|
||||
* Push back when the requested or obvious fix would make the architecture worse
|
||||
* If a cleaner fix requires a small refactor or shared helper/API, propose it explicitly and justify it
|
||||
* Avoid broad refactors unless explicitly requested or clearly necessary for correctness and maintainability
|
||||
* When tests fail, bucket failures by likely root cause and separate patch-related failures from pre-existing or out-of-scope failures
|
||||
|
||||
# Simplicity first
|
||||
|
||||
* Minimum code that solves the problem cleanly. Nothing speculative
|
||||
* No features beyond what was asked
|
||||
* No error handling for impossible scenarios
|
||||
* If you write 200 lines and it could be 50, rewrite it
|
||||
* Ask: “Would a senior engineer say this is overcomplicated?” If yes, simplify
|
||||
* Prefer direct, explicit code over generic machinery unless the generic machinery clearly reduces duplication and preserves boundaries
|
||||
|
||||
# Fallbacks and defaults
|
||||
|
||||
* Avoid silent fallback behavior when the semantic category is unknown
|
||||
* Do not treat “unknown” as “safe” unless the codebase already defines that convention
|
||||
* If a value cannot be classified, either preserve the existing behavior deliberately or fail with a clear diagnostic
|
||||
* When adding a fallback, state why it is semantically valid and what invariant makes it safe
|
||||
|
||||
# Surgical changes
|
||||
|
||||
* Touch only what you must
|
||||
* Clean up only the mess introduced by your own change
|
||||
* Do not “improve” adjacent code, comments, or formatting
|
||||
* Match existing style, even if you would personally do it differently
|
||||
* If you notice unrelated dead code, bad abstractions, or fragile design, mention it separately. Do not delete or rewrite it unless asked
|
||||
* When your changes create orphans, remove imports, variables, functions, or files made unused by your change
|
||||
* Every changed line should trace directly to the requested fix, a required cleanup, or a justified reuse/refactor decision
|
||||
|
||||
# Diagnostics and verification
|
||||
|
||||
* Use existing bounded diagnostic mechanisms for pass-level verification or codegen failures
|
||||
* Do not emit unbounded repeated diagnostics from loops or parallel workers
|
||||
* Diagnostics should identify the violated invariant and the relevant value/op when useful
|
||||
* Verifiers should reject invalid states, not repair them
|
||||
* Codegen should not compensate for invalid IR/data unless codegen is the owner of that invariant
|
||||
* Do not make failing tests pass by weakening verifiers, assertions, or diagnostics unless the check itself is proven wrong
|
||||
* If a check is too strict, explain the valid case it rejects and update the invariant accordingly
|
||||
* Prefer fixing invalid IR/data producers over relaxing consumers
|
||||
* If adding diagnostics only for debugging, remove them or cap them before finalizing
|
||||
|
||||
# Temporary debugging code
|
||||
|
||||
* Temporary diagnostics, dumps, assertions, and debug-only helpers must be removed or intentionally converted into bounded permanent diagnostics before finalizing
|
||||
* If debug instrumentation remains, explain why it is useful as permanent infrastructure
|
||||
* Do not leave noisy validation output behind
|
||||
|
||||
# Performance awareness
|
||||
|
||||
* Avoid algorithmic regressions in compiler passes, especially repeated full-module walks, repeated expensive analyses, or per-op recomputation inside nested loops
|
||||
* If a change adds a walk, cache, analysis, or structural traversal, justify why it is needed
|
||||
* For hot paths, prefer preserving existing asymptotic behavior unless a better structure is part of the requested change
|
||||
* If performance may change, mention the expected impact and suggest a targeted timing check
|
||||
|
||||
# Goal-driven execution
|
||||
|
||||
For multi-step tasks, state a brief plan:
|
||||
|
||||
1. [Step] → verify: [check]
|
||||
2. [Step] → verify: [check]
|
||||
3. [Step] → verify: [check]
|
||||
|
||||
Define success criteria before implementing:
|
||||
|
||||
* For bug fixes, success means reproducing or identifying the failure, fixing the responsible owner, and verifying the targeted case
|
||||
* For refactors, success means preserving behavior while making ownership, reuse, or structure cleaner
|
||||
* For validation changes, success means checking both valid and invalid cases when applicable
|
||||
|
||||
Transform tasks into verifiable goals:
|
||||
|
||||
* “Fix the bug” → identify the invariant, reproduce the failure, fix the owner, verify the targeted case
|
||||
* “Add validation” → write or identify tests for invalid inputs, then make them pass/fail as expected
|
||||
* “Refactor X” → preserve behavior before and after, then run relevant tests
|
||||
|
||||
# Final self-review
|
||||
|
||||
Before reporting completion, check:
|
||||
|
||||
* Did I fix the owner of the invariant rather than masking the issue downstream
|
||||
* Did I avoid broad case lists and ad-hoc special handling
|
||||
* Did I introduce a helper/API only at the right ownership level
|
||||
* Did I migrate directly related duplicate logic when doing so improves compactness
|
||||
* Did I avoid weakening verifiers or assertions unnecessarily
|
||||
* Did I remove temporary debugging code or make it bounded and intentional
|
||||
* Did I avoid unrelated formatting, renames, or cleanup
|
||||
* Did I consider performance impact for added walks, analyses, caches, or repeated computations
|
||||
* Did I run the required build/test commands
|
||||
* Did I clearly report remaining failures or risks
|
||||
|
||||
When reporting back:
|
||||
|
||||
* Say what changed
|
||||
* Say what was verified
|
||||
* Say what remains
|
||||
* When showing code in chat, make it easy to copy-paste into the codebase
|
||||
* Keep outputs focused on the changed parts
|
||||
* List bad practices, fragile assumptions, or cleaner alternatives separately
|
||||
* If a change is intentionally pragmatic rather than architecturally ideal, say so and explain the tradeoff
|
||||
+92
-24
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
||||
|
||||
project(raptor)
|
||||
|
||||
# Add symlink to PIM as accelerator in onnx-mlir
|
||||
function(raptor_ensure_symlink link_path target_path)
|
||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
||||
# Materialize a CMake shim directory
|
||||
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||
|
||||
if(NOT EXISTS "${link_parent}")
|
||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS "${link_path}")
|
||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
||||
file(CREATE_LINK
|
||||
"${target_path}"
|
||||
"${link_path}"
|
||||
SYMBOLIC
|
||||
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||
message(FATAL_ERROR
|
||||
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||
" ${real_external_source_dir}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if (IS_SYMLINK "${shim_dir}")
|
||||
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||
file(REMOVE "${shim_dir}")
|
||||
endif ()
|
||||
|
||||
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||
endif ()
|
||||
|
||||
file(MAKE_DIRECTORY "${shim_dir}")
|
||||
|
||||
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||
set(shim_contents
|
||||
"get_filename_component(raptor_external_source_dir
|
||||
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||
REALPATH
|
||||
)
|
||||
add_subdirectory(
|
||||
\"\${raptor_external_source_dir}\"
|
||||
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||
)
|
||||
if (DEFINED PIM_ENABLED)
|
||||
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||
endif ()
|
||||
"
|
||||
)
|
||||
|
||||
if (EXISTS "${shim_file}")
|
||||
file(READ "${shim_file}" old_contents)
|
||||
else ()
|
||||
set(old_contents "")
|
||||
endif ()
|
||||
|
||||
if (NOT old_contents STREQUAL shim_contents)
|
||||
file(WRITE "${shim_file}" "${shim_contents}")
|
||||
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||
else ()
|
||||
message(STATUS "CMake shim already up to date for ${description}")
|
||||
endif ()
|
||||
|
||||
# Mirror the external tree's first-level entries into the shim directory
|
||||
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||
|
||||
foreach (child IN LISTS children)
|
||||
if (child STREQUAL "CMakeLists.txt")
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
set(real_child "${real_external_source_dir}/${child}")
|
||||
set(shim_child "${shim_dir}/${child}")
|
||||
|
||||
if (IS_SYMLINK "${shim_child}")
|
||||
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||
if (existing_link_target STREQUAL real_child)
|
||||
continue()
|
||||
endif ()
|
||||
file(REMOVE_RECURSE "${shim_child}")
|
||||
elseif (EXISTS "${shim_child}")
|
||||
# Do not delete real files/directories. This protects the generated shim.
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
file(CREATE_LINK
|
||||
"${real_child}"
|
||||
"${shim_child}"
|
||||
SYMBOLIC
|
||||
)
|
||||
endforeach ()
|
||||
endfunction()
|
||||
|
||||
raptor_ensure_symlink(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
"PIM accelerator"
|
||||
)
|
||||
raptor_ensure_symlink(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
"PIM accelerator tests"
|
||||
)
|
||||
|
||||
# Patch onnx-mlir sources for PIM accelerator support.
|
||||
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
|
||||
|
||||
# Already applied – replacement text is present
|
||||
string(FIND "${contents}" "${replacement}" already_applied_pos)
|
||||
if(NOT already_applied_pos EQUAL -1)
|
||||
if (NOT already_applied_pos EQUAL -1)
|
||||
message(STATUS "Patch already applied: ${description}")
|
||||
return()
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
# Anchor must exist for the patch to be applicable
|
||||
string(FIND "${contents}" "${anchor}" anchor_pos)
|
||||
if(anchor_pos EQUAL -1)
|
||||
if (anchor_pos EQUAL -1)
|
||||
message(FATAL_ERROR
|
||||
"Patch anchor not found – onnx-mlir may have changed.\n"
|
||||
" Patch : ${description}\n"
|
||||
" File : ${file_path}\n"
|
||||
" Anchor: ${anchor}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
|
||||
file(WRITE "${file_path}" "${patched}")
|
||||
|
||||
@@ -1,65 +1,321 @@
|
||||
# Raptor
|
||||
|
||||
Raptor is a domain-specific MLIR compiler for neural networks in ONNX format,
|
||||
targeting in-memory computing / processing-in-memory (PIM) architectures. It
|
||||
extends ONNX-MLIR with a PIM accelerator and progressively lowers ONNX-MLIR
|
||||
through custom MLIR dialects to simulator artifacts.
|
||||
|
||||
The current target is the PIM simulator stack under `backend-simulators/pim`.
|
||||
Raptor emits binary per-core `.pim` instruction files by default, plus
|
||||
`memory.bin`, `config.json`, and weight binaries. It can also emit per-core JSON
|
||||
instruction files with `--pim-emit-json`.
|
||||
|
||||
## Overview
|
||||
|
||||
PIM architectures perform most computation directly in memory. The supported
|
||||
target models a chip with:
|
||||
- shared host memory,
|
||||
- multiple PIM cores,
|
||||
- ReRAM crossbars for vector-matrix / matrix-vector work,
|
||||
- explicit communication between cores,
|
||||
- no hardware branch or loop support in emitted simulator code.
|
||||
|
||||
Because repeated work such as convolutions is eventually made explicit, emitted
|
||||
instruction counts can grow quickly. Most compiler work therefore focuses on
|
||||
lowering, scheduling, memory layout, and code-generation optimizations.
|
||||
|
||||
### Targets and simulators
|
||||
|
||||
- `backend-simulators/pim/pim-simulator` is the in-tree Rust functional
|
||||
simulator used by validation. It reads Raptor's `pim/` artifact directory and
|
||||
compares simulator output against native ONNX-MLIR execution.
|
||||
- `backend-simulators/pim/pimsim-nn` is the performance simulator submodule.
|
||||
The helper scripts in `pimcomp_utils/` are for comparison with PIMCOMP-NN and
|
||||
contain local paths; treat them as local utilities, not portable workflows.
|
||||
|
||||
## Compilation pipeline
|
||||
|
||||
The PIM sources live under `src/PIM` and tests under `test/PIM`. CMake exposes
|
||||
them to ONNX-MLIR through generated shim directories under
|
||||
`onnx-mlir/src/Accelerators/PIM` and `onnx-mlir/test/accelerators/PIM`.
|
||||
|
||||
High-level lowering flow:
|
||||
|
||||
```
|
||||
ONNX-MLIR -> Spatial -> Pim (tensor) -> Pim (bufferized) -> PIM artifacts
|
||||
```
|
||||
|
||||
1. **ONNX -> Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
|
||||
Lowers supported ONNX ops into the `spat` dialect
|
||||
(`src/PIM/Dialect/Spatial`). Conversion patterns are split by op family under
|
||||
`Patterns/{Math,NN,Tensor}` and currently cover Conv, Gemm, MatMul,
|
||||
elementwise Add/Mul/Div, ReduceMean, pooling, Relu, Sigmoid, Softmax,
|
||||
Concat, Gather, Reshape, Resize, and Split.
|
||||
|
||||
2. **Merge compute nodes**
|
||||
(`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
|
||||
Builds a compute graph, schedules it with the PEFT scheduler, and materializes
|
||||
the merge schedule into Spatial IR. Supporting scheduling code lives under
|
||||
`MergeComputeNodes/Scheduling`.
|
||||
|
||||
3. **Spatial -> Pim** (`src/PIM/Conversion/SpatialToPim`).
|
||||
Lowers Spatial operations to the `pim` dialect (`src/PIM/Dialect/Pim`),
|
||||
including `pim.core`, `pim.core_batch`, communication, tensor packing, global
|
||||
tensor materialization, and return-path normalization.
|
||||
|
||||
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
|
||||
Converts tensor-semantics PIM IR into memref-semantics PIM IR using MLIR's
|
||||
bufferization interfaces.
|
||||
|
||||
5. **Static memory coalescing**
|
||||
(`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
|
||||
Reuses compatible local memref allocations inside PIM cores before codegen.
|
||||
|
||||
6. **PIM code generation** (`src/PIM/Pass/PimCodegen` and
|
||||
`src/PIM/Compiler`).
|
||||
Folds host constants, materializes remaining host constants, verifies PIM IR,
|
||||
emits `.pim` core files, writes weights, and writes `memory.bin` /
|
||||
`config.json`.
|
||||
|
||||
Supporting pieces:
|
||||
- `src/PIM/Common` - shared IR, filesystem, diagnostics, reports, and utility
|
||||
helpers.
|
||||
- `src/PIM/Compiler` - PIM compiler options, memory/address planning, binary
|
||||
instruction format, artifact writing, weight emission, and codegen entry
|
||||
points.
|
||||
- `src/PIM/Conversion/SpatialToGraphviz` - optional Spatial graphviz conversion
|
||||
pass.
|
||||
- `src/PIM/Pass` - pass registration and auxiliary passes.
|
||||
- `src/PIM/PimAccelerator.{cpp,hpp}` - ONNX-MLIR accelerator entry point.
|
||||
|
||||
## Key compiler options
|
||||
|
||||
Pass these to `onnx-mlir` when compiling for PIM:
|
||||
|
||||
- `--maccel=PIM` - select the PIM accelerator.
|
||||
- `--EmitSpatial`, `--EmitPim`, `--EmitPimBufferized`,
|
||||
`--EmitPimCodegen` - stop the PIM pipeline at the requested stage. The PIM
|
||||
default is `--EmitPimCodegen`.
|
||||
- `--core-count=<N>` - required positive core count for PIM compilation.
|
||||
- `--crossbar-size=<N>` - crossbar width/height. Default in code is `2`.
|
||||
- `--crossbar-count=<N>` - crossbars per core. Default in code is `256`.
|
||||
- `--pim-merge-scheduler=peft` - merge scheduler. `peft` is the only accepted
|
||||
value in the current code.
|
||||
- `--pim-only-codegen` - assume input is already bufferized PIM IR and only run
|
||||
the codegen tail.
|
||||
- `--pim-emit-json` - also emit `core_*.json` instruction files alongside
|
||||
`core_*.pim`.
|
||||
- `--use-experimental-conv-impl` - use the alternate convolution lowering.
|
||||
- `--ignore-concat-error` - soft-fail a ConcatOp corner case.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
./build_release/Release/bin/onnx-mlir model.onnx -o /tmp/raptor/model \
|
||||
--maccel=PIM --EmitPimCodegen \
|
||||
--crossbar-size=2048 --crossbar-count=256 --core-count=1000
|
||||
```
|
||||
|
||||
This writes PIM artifacts under `/tmp/raptor/pim/`.
|
||||
|
||||
## Validation
|
||||
|
||||
Functional validation lives in `validation/`. It compiles ONNX models, builds a
|
||||
native ONNX-MLIR reference runner, generates random inputs, runs Raptor, runs
|
||||
the Rust PIM simulator, and compares outputs.
|
||||
|
||||
Python dependencies used by the validation scripts are `numpy`, `onnx`, and
|
||||
`colorama`. The simulator requires the Rust toolchain.
|
||||
|
||||
Per-operation validation from the repository root:
|
||||
|
||||
```bash
|
||||
python3 validation/validate.py \
|
||||
--raptor-path build_release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir onnx-mlir/include \
|
||||
--core-count 1000
|
||||
```
|
||||
|
||||
Validate one network or a subset by pointing `--operations-dir` at any directory
|
||||
containing `.onnx` files:
|
||||
|
||||
```bash
|
||||
python3 validation/validate.py \
|
||||
--raptor-path build_release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir onnx-mlir/include \
|
||||
--operations-dir validation/networks/yolo11n/depth_04 \
|
||||
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||
```
|
||||
|
||||
Useful validation options:
|
||||
- `--simulator-dir <path>` - override the auto-detected
|
||||
`backend-simulators/pim/pim-simulator` path.
|
||||
- `--threshold <float>` - maximum allowed per-element output difference.
|
||||
- `--seed <int>` - RNG seed for generated inputs.
|
||||
- `--command-timeout-seconds <float>` - timeout for compiler, runner, and
|
||||
simulator subprocesses.
|
||||
- `--verbose` - print subprocess logs and average PIM pass timings.
|
||||
- `--clean` - remove generated validation artifacts and exit.
|
||||
|
||||
Each validation run writes artifacts in the model workspace, for example under
|
||||
`validation/operations/gemm/small/`:
|
||||
- `inputs/` - generated input CSV files.
|
||||
- `outputs/` - native ONNX-MLIR reference outputs.
|
||||
- `raptor/` - compiler artifacts, including `*.onnx.mlir`, dialect dumps under
|
||||
`dialects/`, reports under `reports/`, and final PIM artifacts under `pim/`.
|
||||
- `runner/` - generated reference runner source, build tree, and shared library.
|
||||
- `simulation/out.bin` - raw simulator output used for comparison.
|
||||
|
||||
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
|
||||
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
|
||||
`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
|
||||
available.
|
||||
|
||||
To rerun the simulator manually with tracing after validation has produced a
|
||||
`raptor/pim/` directory:
|
||||
|
||||
```bash
|
||||
cd backend-simulators/pim/pim-simulator
|
||||
cargo run --no-default-features --features tracing --release \
|
||||
--package pim-simulator --bin pim-simulator -- \
|
||||
-f /path/to/workspace/raptor/pim \
|
||||
-o /path/to/workspace/simulation/out.bin \
|
||||
-d <addr0>,<size0>,<addr1>,<size1>,...
|
||||
```
|
||||
|
||||
With `--features tracing`, the simulator writes per-core traces as
|
||||
`TraceCore0`, `TraceCore1`, ... next to `out.bin`. The validator normally
|
||||
computes the `-d` ranges from `raptor/pim/config.json` and model output shapes.
|
||||
|
||||
Available validation networks under `validation/networks/`: `vgg16`,
|
||||
`yolo11n`, `yolo11nv2`.
|
||||
|
||||
Available operation suites under `validation/operations/`: `add`, `concat`,
|
||||
`conv`, `div`, `gather`, `gemm`, `gemv`, `matmul`, `mul`, `pool`,
|
||||
`reduce_mean`, `relu`, `reshape`, `resize`, `sigmoid`, `softmax`, `split`.
|
||||
|
||||
Generated operation tests can be regenerated with:
|
||||
|
||||
```bash
|
||||
python3 validation/operations/gen_tests.py
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
Initialize submodules first:
|
||||
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
The project follows ONNX-MLIR's build requirements. The CI workflow documents
|
||||
the currently used versions and setup:
|
||||
- CMake 4.3.0 in CI,
|
||||
- LLVM/MLIR checked out under `onnx-mlir/llvm-project`,
|
||||
- Protobuf `v34.0`,
|
||||
- Rust stable for `pim-simulator`,
|
||||
- Python packages `numpy`, `onnx`, `colorama` for validation.
|
||||
|
||||
### Protobuf
|
||||
|
||||
Use the following commands to install protobuf:
|
||||
```
|
||||
Install Protobuf if your system does not already provide a compatible version:
|
||||
|
||||
```bash
|
||||
git clone --depth 1 --branch v34.0 https://github.com/protocolbuffers/protobuf
|
||||
cd protobuf
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
|
||||
ninja
|
||||
sudo ninja install
|
||||
cmake -S protobuf -B protobuf/build -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-Dprotobuf_BUILD_TESTS=OFF
|
||||
cmake --build protobuf/build
|
||||
sudo cmake --install protobuf/build
|
||||
```
|
||||
|
||||
You can now remove the protobuf repo directory with:
|
||||
```
|
||||
cd ../..
|
||||
You can then remove the temporary checkout:
|
||||
|
||||
```bash
|
||||
rm -rf protobuf
|
||||
```
|
||||
|
||||
### Mlir
|
||||
### MLIR
|
||||
|
||||
Follow the first part of instructions [here](onnx-mlir/docs/BuildOnLinuxOSX.md) to build mlir.
|
||||
Follow the ONNX-MLIR instructions in
|
||||
`onnx-mlir/docs/BuildOnLinuxOSX.md` to build LLVM/MLIR. The local Raptor build
|
||||
expects `MLIR_DIR` to point at the MLIR CMake package, for example:
|
||||
|
||||
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
|
||||
|
||||
Moreover, if compiling with build type debug, it is also suggested to use
|
||||
mold as linker (you will need to install it if you don't have it already)
|
||||
to reduce memory usage during linking. You can use it by setting the options:
|
||||
```
|
||||
-DLLVM_USE_LINKER=mold
|
||||
```bash
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
|
||||
```
|
||||
|
||||
If your LLVM build directory is named `build` instead of `build_release`, adjust
|
||||
the path accordingly.
|
||||
|
||||
### Raptor
|
||||
|
||||
Use the following commands to build Raptor.
|
||||
Configure a release build:
|
||||
|
||||
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
|
||||
|
||||
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
|
||||
setting the options:
|
||||
```
|
||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
|
||||
```
|
||||
|
||||
```
|
||||
git submodule update --init --recursive
|
||||
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
|
||||
mkdir build && cd build
|
||||
cmake .. -G Ninja \
|
||||
```bash
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
|
||||
cmake -S . -B build_release -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR}
|
||||
cmake --build .
|
||||
```
|
||||
|
||||
If the build fails because of protobuf missing uint definitions,
|
||||
just patch the problematic files by adding ```#include <cstdint>``` to their includes.
|
||||
Configure a debug build similarly:
|
||||
|
||||
```bash
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_debug/lib/cmake/mlir
|
||||
cmake -S . -B build_debug -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR}
|
||||
```
|
||||
|
||||
For debug development, using `mold` can reduce link time and memory use:
|
||||
|
||||
```bash
|
||||
cmake -S . -B build_debug -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR} \
|
||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
|
||||
```
|
||||
|
||||
Build the compiler with CMake:
|
||||
|
||||
```bash
|
||||
cmake --build ./build_release
|
||||
cmake --build ./build_debug
|
||||
```
|
||||
|
||||
Do not invoke `ninja` directly for this project; use `cmake --build` so CMake's
|
||||
configuration and generated shims stay consistent.
|
||||
|
||||
If a build fails because Protobuf headers are missing fixed-width integer
|
||||
definitions, patch the affected Protobuf-generated files by adding
|
||||
`#include <cstdint>`.
|
||||
|
||||
## Tests
|
||||
|
||||
The Rust simulator has its own tests:
|
||||
|
||||
```bash
|
||||
cd backend-simulators/pim/pim-simulator
|
||||
cargo test
|
||||
```
|
||||
|
||||
## Repository Layout
|
||||
|
||||
- `src/PIM/` - PIM accelerator implementation.
|
||||
- `test/PIM/` - PIM C++ unit tests.
|
||||
- `validation/` - functional validation scripts, ONNX operation tests, network
|
||||
slices, and pimsim config generation.
|
||||
- `backend-simulators/pim/pim-simulator/` - in-tree Rust functional simulator.
|
||||
- `backend-simulators/pim/pimsim-nn/` - performance simulator submodule.
|
||||
- `pimcomp_utils/` - local comparison helpers for PIMCOMP-NN.
|
||||
- `.github/actions/` and `.github/workflows/validate_operations.yml` - CI setup
|
||||
for MLIR/Protobuf caching, building Raptor, and validation.
|
||||
|
||||
+2121
-8
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,3 @@
|
||||
|
||||
[package]
|
||||
name = "pim-simulator"
|
||||
version = "0.1.0"
|
||||
@@ -13,8 +12,9 @@ name = "pimcore"
|
||||
path = "src/lib/pimcore.rs"
|
||||
|
||||
[features]
|
||||
default = ["tracing"]
|
||||
default = []
|
||||
tracing = []
|
||||
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
|
||||
|
||||
|
||||
|
||||
@@ -27,3 +27,10 @@ hex = "0"
|
||||
paste = "1"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
statrs = {version="0.16", optional=true}
|
||||
comfy-table = {version="7.1", optional=true}
|
||||
plotly = {version="0.8", optional=true}
|
||||
rayon = "1.12.0"
|
||||
faer = "0.24.0"
|
||||
faer-traits = "0.24.0"
|
||||
mimalloc = "0.1.50"
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
|
||||
use anyhow::{Context, Result, bail};
|
||||
use clap::Parser;
|
||||
use glob::glob;
|
||||
use pimcore::binary_to_instruction::binary_to_executor;
|
||||
use pimcore::cpu::crossbar::Crossbar;
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::memory_manager::CoreMemory;
|
||||
use pimcore::tracing::TRACER;
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fs::{self, read_link};
|
||||
use std::io::Write;
|
||||
use std::fs::{self, File, read_link};
|
||||
use std::io::{BufReader, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Program to simulate core execution configuration
|
||||
@@ -37,25 +43,31 @@ struct Args {
|
||||
|
||||
/// Comma separated list of (address,size) for memory output dump
|
||||
#[arg(short, long, value_delimiter = ',', num_args = 1.., value_name = "ADDR,SIZE")]
|
||||
dump: Vec<i32>,
|
||||
dump: Vec<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let config_json = retrive_config(&args)?;
|
||||
let core_jsons = retrive_cores(&args)?;
|
||||
let mut core_inputs = retrive_cores(&args)?;
|
||||
let memory = retrive_memory(&args)?;
|
||||
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
|
||||
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
|
||||
let mut executor =
|
||||
json_to_executor::json_to_executor(config_json, core_jsons.iter(), crossbars);
|
||||
let mut executor = match &mut core_inputs {
|
||||
CoreInputs::Json(core_jsons) => {
|
||||
json_to_executor::json_to_executor(config_json, core_jsons, crossbars)
|
||||
}
|
||||
CoreInputs::Binary(core_bins) => {
|
||||
binary_to_executor(config_json, core_bins.iter(), crossbars)?
|
||||
}
|
||||
};
|
||||
set_memory(&mut executor, memory);
|
||||
TRACER
|
||||
.lock()
|
||||
.unwrap()
|
||||
.init(executor.cpu().num_core(), args.output.clone());
|
||||
executor.execute();
|
||||
executor.execute()?;
|
||||
dump_memory(executor, &args)?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -65,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
|
||||
args: &Args,
|
||||
global_crossbars: &'c HashMap<String, Crossbar>,
|
||||
) -> Vec<Vec<&'c Crossbar>> {
|
||||
let mut res = Vec::new();
|
||||
let mut res = vec![Vec::new()];
|
||||
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||
|
||||
if let Some(folder) = args.folder.as_ref() {
|
||||
@@ -140,8 +152,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
|
||||
}
|
||||
|
||||
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
|
||||
let mut crossbar =
|
||||
Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||
let mut crossbar = Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes).unwrap();
|
||||
res.insert(
|
||||
weight_file
|
||||
@@ -157,7 +168,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
|
||||
}
|
||||
|
||||
fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
|
||||
let dumps: Vec<(i32, i32)> = args
|
||||
let dumps: Vec<(usize, usize)> = args
|
||||
.dump
|
||||
.chunks_exact(2)
|
||||
.map(|chunk| (chunk[0], chunk[1]))
|
||||
@@ -214,45 +225,82 @@ fn retrive_memory(args: &Args) -> Result<Vec<u8>> {
|
||||
Ok(memory_vector)
|
||||
}
|
||||
|
||||
fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
|
||||
let mut core_jsons: Vec<Value> = Vec::new();
|
||||
if let Some(cores_override) = &args.cores {
|
||||
for core in cores_override {
|
||||
let content = fs::read_to_string(core)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", cores_override))?;
|
||||
let json: Value =
|
||||
serde_json::from_str(&content).context("Failed to parse core json override")?;
|
||||
core_jsons.push(json);
|
||||
}
|
||||
} else if let Some(folder) = args.folder.as_ref() {
|
||||
let pattern = folder.join("core*.json");
|
||||
let pattern_str = pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
paths.sort_by_cached_key(|x| {
|
||||
let mut x = x
|
||||
.file_stem()
|
||||
.expect("Extracting the stem")
|
||||
.to_str()
|
||||
.expect("File not utf-8");
|
||||
x = &x[5..];
|
||||
x.parse::<i32>().unwrap()
|
||||
});
|
||||
enum CoreInputs {
|
||||
Json(Vec<BufReader<File>>),
|
||||
Binary(Vec<Vec<u8>>),
|
||||
}
|
||||
|
||||
if paths.is_empty() {
|
||||
bail!("No core*.json files found in {:?}", folder);
|
||||
fn retrive_cores(args: &Args) -> Result<CoreInputs, anyhow::Error> {
|
||||
if let Some(cores_override) = &args.cores {
|
||||
let first_extension = cores_override
|
||||
.first()
|
||||
.and_then(|path| path.extension())
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or_default();
|
||||
if first_extension == "pim" {
|
||||
let mut core_bins = Vec::with_capacity(cores_override.len());
|
||||
for core in cores_override {
|
||||
core_bins.push(
|
||||
fs::read(core)
|
||||
.with_context(|| format!("Failed to read binary core file: {:?}", core))?,
|
||||
);
|
||||
}
|
||||
return Ok(CoreInputs::Binary(core_bins));
|
||||
}
|
||||
for entry in paths {
|
||||
let path = entry;
|
||||
let content = fs::read_to_string(&path)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", path))?;
|
||||
let json: Value = serde_json::from_str(&content)
|
||||
.with_context(|| format!("Failed to parse JSON in {:?}", path))?;
|
||||
core_jsons.push(json);
|
||||
let mut core_jsons_reader: Vec<BufReader<File>> = Vec::with_capacity(cores_override.len());
|
||||
for core in cores_override {
|
||||
let file = File::open(core)?;
|
||||
let reader = BufReader::new(file);
|
||||
core_jsons_reader.push(reader);
|
||||
}
|
||||
} else {
|
||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||
return Ok(CoreInputs::Json(core_jsons_reader));
|
||||
}
|
||||
Ok(core_jsons)
|
||||
|
||||
if let Some(folder) = args.folder.as_ref() {
|
||||
let binary_pattern = folder.join("core*.pim");
|
||||
let binary_pattern_str = binary_pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut binary_paths: Vec<_> = glob(binary_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
binary_paths.sort_by_cached_key(core_sort_key);
|
||||
if !binary_paths.is_empty() {
|
||||
let mut core_bins = Vec::with_capacity(binary_paths.len());
|
||||
for path in binary_paths {
|
||||
core_bins.push(
|
||||
fs::read(&path)
|
||||
.with_context(|| format!("Failed to read core file: {:?}", path))?,
|
||||
);
|
||||
}
|
||||
return Ok(CoreInputs::Binary(core_bins));
|
||||
}
|
||||
|
||||
let json_pattern = folder.join("core*.json");
|
||||
let json_pattern_str = json_pattern.to_str().context("Invalid path encoding")?;
|
||||
let mut json_paths: Vec<_> = glob(json_pattern_str)?.map(|x| x.unwrap()).collect();
|
||||
json_paths.sort_by_cached_key(core_sort_key);
|
||||
|
||||
if json_paths.is_empty() {
|
||||
bail!("No core*.pim or core*.json files found in {:?}", folder);
|
||||
}
|
||||
|
||||
let mut core_json_reader: Vec<BufReader<File>> = Vec::with_capacity(json_paths.len());
|
||||
for path in json_paths {
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
core_json_reader.push(reader);
|
||||
}
|
||||
return Ok(CoreInputs::Json(core_json_reader));
|
||||
}
|
||||
|
||||
bail!("Either --core or --folder must be provided to find core definitions.");
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &PathBuf) -> i32 {
|
||||
let mut stem = path
|
||||
.file_stem()
|
||||
.expect("Extracting the stem")
|
||||
.to_str()
|
||||
.expect("File not utf-8");
|
||||
stem = &stem[5..];
|
||||
stem.parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
use crate::{
|
||||
CoreInstructionsBuilder, Executable,
|
||||
cpu::{CPU, crossbar::Crossbar},
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
use anyhow::{Context, Result, bail, ensure};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::convert::TryFrom;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
const MAGIC: &[u8; 4] = b"PIMB";
|
||||
const VERSION: u32 = 1;
|
||||
const HEADER_SIZE: usize = 12;
|
||||
const RECORD_SIZE: usize = 20;
|
||||
|
||||
macro_rules! add_name {
|
||||
($storage:ident, $opcode:literal, $name:literal) => {
|
||||
$storage.insert($opcode, $name);
|
||||
};
|
||||
}
|
||||
|
||||
static INSTRUCTIONS: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
let mut hash = HashMap::new();
|
||||
add_name!(hash, 0, "nop");
|
||||
add_name!(hash, 1, "sldi");
|
||||
add_name!(hash, 2, "sld");
|
||||
add_name!(hash, 3, "sadd");
|
||||
add_name!(hash, 4, "ssub");
|
||||
add_name!(hash, 5, "smul");
|
||||
add_name!(hash, 6, "saddi");
|
||||
add_name!(hash, 7, "smuli");
|
||||
add_name!(hash, 8, "setbw");
|
||||
add_name!(hash, 9, "mvmul");
|
||||
add_name!(hash, 10, "vvadd");
|
||||
add_name!(hash, 11, "vvsub");
|
||||
add_name!(hash, 12, "vvmul");
|
||||
add_name!(hash, 13, "vvdmul");
|
||||
add_name!(hash, 14, "vvmax");
|
||||
add_name!(hash, 15, "vvsll");
|
||||
add_name!(hash, 16, "vvsra");
|
||||
add_name!(hash, 17, "vavg");
|
||||
add_name!(hash, 18, "vrelu");
|
||||
add_name!(hash, 19, "vtanh");
|
||||
add_name!(hash, 20, "vsigm");
|
||||
add_name!(hash, 21, "vsoftmax");
|
||||
add_name!(hash, 22, "vmv");
|
||||
add_name!(hash, 23, "vrsu");
|
||||
add_name!(hash, 24, "vrsl");
|
||||
add_name!(hash, 25, "ld");
|
||||
add_name!(hash, 26, "st");
|
||||
add_name!(hash, 27, "lldi");
|
||||
add_name!(hash, 28, "lmv");
|
||||
add_name!(hash, 29, "send");
|
||||
add_name!(hash, 30, "recv");
|
||||
add_name!(hash, 31, "wait");
|
||||
add_name!(hash, 32, "sync");
|
||||
hash
|
||||
});
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
struct InstructionRecord {
|
||||
opcode: u8,
|
||||
rd: u8,
|
||||
r1: u8,
|
||||
r2_or_imm: i32,
|
||||
generic1: i32,
|
||||
generic2: i32,
|
||||
generic3: i32,
|
||||
flags: u8,
|
||||
}
|
||||
|
||||
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
|
||||
u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn read_i32_le(bytes: &[u8], offset: usize) -> i32 {
|
||||
i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
|
||||
}
|
||||
|
||||
fn parse_binary_records(bytes: &[u8]) -> Result<Vec<InstructionRecord>> {
|
||||
ensure!(bytes.len() >= HEADER_SIZE, "binary core file too small");
|
||||
ensure!(&bytes[0..4] == MAGIC, "invalid PIM binary magic");
|
||||
|
||||
let version = read_u32_le(bytes, 4);
|
||||
ensure!(
|
||||
version == VERSION,
|
||||
"unsupported PIM binary version {version}"
|
||||
);
|
||||
|
||||
let instruction_count = read_u32_le(bytes, 8) as usize;
|
||||
let expected_len = HEADER_SIZE + instruction_count * RECORD_SIZE;
|
||||
ensure!(
|
||||
bytes.len() == expected_len,
|
||||
"PIM binary size mismatch: expected {expected_len} bytes, got {}",
|
||||
bytes.len()
|
||||
);
|
||||
|
||||
let mut records = Vec::with_capacity(instruction_count);
|
||||
for index in 0..instruction_count {
|
||||
let base = HEADER_SIZE + index * RECORD_SIZE;
|
||||
records.push(InstructionRecord {
|
||||
opcode: bytes[base],
|
||||
rd: bytes[base + 1],
|
||||
r1: bytes[base + 2],
|
||||
flags: bytes[base + 3],
|
||||
r2_or_imm: read_i32_le(bytes, base + 4),
|
||||
generic1: read_i32_le(bytes, base + 8),
|
||||
generic2: read_i32_le(bytes, base + 12),
|
||||
generic3: read_i32_le(bytes, base + 16),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
fn append_record(
|
||||
inst_builder: &mut InstructionsBuilder,
|
||||
inst_data_builder: &mut InstructionDataBuilder,
|
||||
record: InstructionRecord,
|
||||
) -> Result<()> {
|
||||
let InstructionRecord {
|
||||
opcode,
|
||||
rd,
|
||||
r1,
|
||||
r2_or_imm,
|
||||
generic1,
|
||||
generic2,
|
||||
generic3,
|
||||
flags: _,
|
||||
} = record;
|
||||
|
||||
match opcode {
|
||||
0 => {}
|
||||
1 => {
|
||||
inst_data_builder.set_rd_u8(rd).set_imm(r2_or_imm);
|
||||
inst_builder.make_inst(sldi, inst_data_builder.build());
|
||||
}
|
||||
2 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_r1_u8(r1)
|
||||
.set_offset_select(generic1)
|
||||
.set_offset_value(generic2);
|
||||
inst_builder.make_inst(sld, inst_data_builder.build());
|
||||
}
|
||||
3 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(sadd, inst_data_builder.build());
|
||||
}
|
||||
4 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(ssub, inst_data_builder.build());
|
||||
}
|
||||
5 => {
|
||||
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(smul, inst_data_builder.build());
|
||||
}
|
||||
6 => {
|
||||
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(saddi, inst_data_builder.build());
|
||||
}
|
||||
7 => {
|
||||
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
|
||||
inst_builder.make_inst(smuli, inst_data_builder.build());
|
||||
}
|
||||
8 => {
|
||||
inst_data_builder.set_ibiw_obiw(generic1, generic2);
|
||||
inst_builder.make_inst(setbw, inst_data_builder.build());
|
||||
}
|
||||
9 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_r1_u8(r1)
|
||||
.set_mbiw_immrelu_immgroup(r2_or_imm, generic1, generic2);
|
||||
inst_builder.make_inst(mvmul, inst_data_builder.build());
|
||||
}
|
||||
10 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvadd, inst_data_builder.build());
|
||||
}
|
||||
11 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsub, inst_data_builder.build());
|
||||
}
|
||||
12 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvmul, inst_data_builder.build());
|
||||
}
|
||||
13 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvdmul, inst_data_builder.build());
|
||||
}
|
||||
14 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvmax, inst_data_builder.build());
|
||||
}
|
||||
15 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsll, inst_data_builder.build());
|
||||
}
|
||||
16 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vvsra, inst_data_builder.build());
|
||||
}
|
||||
17 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vavg, inst_data_builder.build());
|
||||
}
|
||||
18 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrelu, inst_data_builder.build());
|
||||
}
|
||||
19 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vtanh, inst_data_builder.build());
|
||||
}
|
||||
20 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vsigm, inst_data_builder.build());
|
||||
}
|
||||
21 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
|
||||
}
|
||||
22 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vmv, inst_data_builder.build());
|
||||
}
|
||||
23 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrsu, inst_data_builder.build());
|
||||
}
|
||||
24 => {
|
||||
inst_data_builder
|
||||
.set_rdr1r2_u8(rd, r1, r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(vrsl, inst_data_builder.build());
|
||||
}
|
||||
25 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(ld, inst_data_builder.build());
|
||||
}
|
||||
26 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(st, inst_data_builder.build());
|
||||
}
|
||||
27 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm(r2_or_imm)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(lldi, inst_data_builder.build());
|
||||
}
|
||||
28 => {
|
||||
inst_data_builder
|
||||
.set_rdr1_u8(rd, r1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(lmv, inst_data_builder.build());
|
||||
}
|
||||
29 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm_core(r2_or_imm + 1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(send, inst_data_builder.build());
|
||||
}
|
||||
30 => {
|
||||
inst_data_builder
|
||||
.set_rd_u8(rd)
|
||||
.set_imm_core(r2_or_imm + 1)
|
||||
.set_imm_len(generic3)
|
||||
.set_offset_select_value(generic1, generic2);
|
||||
inst_builder.make_inst(recv, inst_data_builder.build());
|
||||
}
|
||||
31 => {
|
||||
inst_builder.make_inst(wait, inst_data_builder.build());
|
||||
}
|
||||
32 => {
|
||||
inst_builder.make_inst(sync, inst_data_builder.build());
|
||||
}
|
||||
_ => bail!("unsupported PIM binary opcode {opcode}"),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn binary_to_instructions(
|
||||
core_bytes: &[u8],
|
||||
core_index: i32,
|
||||
) -> Result<Vec<crate::instruction_set::Instruction>> {
|
||||
let records = parse_binary_records(core_bytes)?;
|
||||
let mut insts_builder = InstructionsBuilder::new();
|
||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||
inst_data_builder
|
||||
.set_core_indx_u16(u16::try_from(core_index).expect("core index does not fit in u16"))
|
||||
.fix_core_indx();
|
||||
|
||||
for record in records {
|
||||
let opcode = record.opcode;
|
||||
let name = INSTRUCTIONS
|
||||
.get(&(opcode as usize))
|
||||
.copied()
|
||||
.unwrap_or("<unknown>");
|
||||
|
||||
append_record(&mut insts_builder, &mut inst_data_builder, record).with_context(|| {
|
||||
format!(
|
||||
"while decoding binary instruction for core {core_index}: opcode {opcode} ({name})"
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(insts_builder.build())
|
||||
}
|
||||
|
||||
pub fn binary_to_executor<'a, 'b>(
|
||||
config: Value,
|
||||
cores: impl Iterator<Item = &'b Vec<u8>>,
|
||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||
) -> Result<Executable<'a>> {
|
||||
let core_cnt = config
|
||||
.get("core_cnt")
|
||||
.context("missing core_cnt in config")?
|
||||
.as_i64()
|
||||
.context("core_cnt is not an integer")? as i32;
|
||||
|
||||
let cpu = CPU::new(core_cnt, crossbars);
|
||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||
for (external_core_indx, core_bytes) in cores.enumerate() {
|
||||
let core_indx = external_core_indx as i32 + 1;
|
||||
let instructions = binary_to_instructions(core_bytes, core_indx)?;
|
||||
core_insts_builder.set_core(core_indx, instructions);
|
||||
}
|
||||
|
||||
Ok(Executable::new(cpu, core_insts_builder.build()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{
|
||||
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
|
||||
};
|
||||
use crate::{
|
||||
functor_to_name,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa::json_to_instruction,
|
||||
};
|
||||
|
||||
fn encode_record(record: InstructionRecord, dst: &mut Vec<u8>) {
|
||||
dst.push(record.opcode);
|
||||
dst.push(record.rd);
|
||||
dst.push(record.r1);
|
||||
dst.push(record.flags);
|
||||
dst.extend_from_slice(&record.r2_or_imm.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic1.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic2.to_le_bytes());
|
||||
dst.extend_from_slice(&record.generic3.to_le_bytes());
|
||||
}
|
||||
|
||||
fn binary_blob(records: &[InstructionRecord]) -> Vec<u8> {
|
||||
let mut blob = Vec::with_capacity(HEADER_SIZE + records.len() * RECORD_SIZE);
|
||||
blob.extend_from_slice(MAGIC);
|
||||
blob.extend_from_slice(&VERSION.to_le_bytes());
|
||||
blob.extend_from_slice(&(records.len() as u32).to_le_bytes());
|
||||
for &record in records {
|
||||
encode_record(record, &mut blob);
|
||||
}
|
||||
blob
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_and_binary_decoders_match_for_representative_ops() {
|
||||
let json_program = [
|
||||
r#"{"imm":64,"op":"sldi","rd":0}"#,
|
||||
r#"{"imm":128,"op":"sldi","rd":1}"#,
|
||||
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"lmv","rd":0,"rs1":1}"#,
|
||||
r#"{"group":3,"mbiw":8,"op":"mvmul","rd":0,"relu":0,"rs1":1}"#,
|
||||
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"vvadd","rd":0,"rs1":1,"rs2":2}"#,
|
||||
r#"{"core":2,"offset":{"offset_select":0,"offset_value":0},"op":"send","rd":0,"size":16}"#,
|
||||
];
|
||||
|
||||
let binary_program = binary_blob(&[
|
||||
InstructionRecord {
|
||||
opcode: 1,
|
||||
rd: 0,
|
||||
r2_or_imm: 64,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 1,
|
||||
rd: 1,
|
||||
r2_or_imm: 128,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 28,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 9,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
r2_or_imm: 8,
|
||||
generic2: 3,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 10,
|
||||
rd: 0,
|
||||
r1: 1,
|
||||
r2_or_imm: 2,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
InstructionRecord {
|
||||
opcode: 29,
|
||||
rd: 0,
|
||||
r2_or_imm: 2,
|
||||
generic3: 16,
|
||||
..Default::default()
|
||||
},
|
||||
]);
|
||||
|
||||
let mut json_builder = InstructionsBuilder::new();
|
||||
let mut json_data_builder = InstructionDataBuilder::new();
|
||||
json_data_builder.set_core_indx(1).fix_core_indx();
|
||||
for inst in json_program {
|
||||
let value = serde_json::from_str(inst).unwrap();
|
||||
json_to_instruction(&mut json_builder, &mut json_data_builder, &value);
|
||||
}
|
||||
let json_instructions = json_builder.build();
|
||||
let binary_instructions = binary_to_instructions(&binary_program, 1).unwrap();
|
||||
|
||||
assert_eq!(json_instructions.len(), binary_instructions.len());
|
||||
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
|
||||
assert_eq!(
|
||||
functor_to_name(json_inst.functor as usize),
|
||||
functor_to_name(binary_inst.functor as usize)
|
||||
);
|
||||
assert_eq!(json_inst.data, binary_inst.data);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::utility::AddressArg;
|
||||
use std::{collections::HashMap, fmt::Debug};
|
||||
use anyhow::{Context, Result, ensure};
|
||||
|
||||
@@ -9,6 +10,7 @@ use crate::{
|
||||
|
||||
pub mod crossbar;
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CPU<'a> {
|
||||
cores: Box<[Core<'a>]>,
|
||||
@@ -91,30 +93,26 @@ impl<'a> Core<'a> {
|
||||
self.memory.execute_load()
|
||||
}
|
||||
|
||||
pub fn execute_store<T>(&mut self, address: impl TryToUsize, element: &[T]) -> Result<()>
|
||||
pub fn execute_store<T>(&mut self, address: impl AddressArg, element: &[T]) -> Result<()>
|
||||
where
|
||||
T: MemoryStorable,
|
||||
{
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
let address = address.to_address_usize()?;
|
||||
self.memory.execute_store(address, element)
|
||||
}
|
||||
|
||||
pub fn reserve_load(
|
||||
&mut self,
|
||||
address: impl TryToUsize,
|
||||
address: impl AddressArg,
|
||||
size: impl TryToUsize,
|
||||
) -> Result<&mut CoreMemory> {
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
let address = address.to_address_usize()?;
|
||||
let size = size.try_into().context("size can not be negative")?;
|
||||
self.memory.reserve_load(address, size)
|
||||
}
|
||||
|
||||
pub fn set_register(&mut self, index: impl TryToUsize, value: i32) {
|
||||
let index = index.try_into().expect("index can not be negative");
|
||||
assert!(
|
||||
value >= 0,
|
||||
"Register cannot be negative if happens remove this and go check where it's used as usize"
|
||||
);
|
||||
self.registers[index] = value;
|
||||
}
|
||||
|
||||
@@ -123,11 +121,11 @@ impl<'a> Core<'a> {
|
||||
self.registers[index]
|
||||
}
|
||||
|
||||
pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
|
||||
pub fn load<T>(&mut self, address: impl AddressArg, size: impl TryToUsize) -> Result<Vec<&[T]>>
|
||||
where
|
||||
T: MemoryStorable,
|
||||
{
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
let address = address.to_address_usize()?;
|
||||
let size = size.try_into().context("size can not be negative")?;
|
||||
self.memory.load(address, size)
|
||||
}
|
||||
@@ -141,8 +139,8 @@ impl<'a> Core<'a> {
|
||||
(memory, crossbars)
|
||||
}
|
||||
|
||||
pub fn memset(&mut self, address: impl TryToUsize, size: impl TryToUsize, val: u8) -> Result<()> {
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
pub fn memset(&mut self, address: impl AddressArg, size: impl TryToUsize, val: u8) -> Result<()> {
|
||||
let address = address.to_address_usize()?;
|
||||
let size = size.try_into().context("size can not be negative")?;
|
||||
self.memory.memset(address, size, val)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use paste::paste;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default)]
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
pub struct InstructionData {
|
||||
core_indx: i32,
|
||||
rd: i32,
|
||||
r1: i32,
|
||||
core_indx: u16,
|
||||
rd: u8,
|
||||
r1: u8,
|
||||
//r2 imm mbiw imm_core
|
||||
r2_or_imm: i32,
|
||||
//offset_select imm_relu ibiw
|
||||
@@ -16,18 +17,30 @@ pub struct InstructionData {
|
||||
}
|
||||
|
||||
impl InstructionData {
|
||||
pub fn core_indx(&self) -> i32 {
|
||||
pub fn core_indx_u16(&self) -> u16 {
|
||||
self.core_indx
|
||||
}
|
||||
|
||||
pub fn rd(&self) -> i32 {
|
||||
pub fn core_indx(&self) -> i32 {
|
||||
i32::from(self.core_indx)
|
||||
}
|
||||
|
||||
pub fn rd_u8(&self) -> u8 {
|
||||
self.rd
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> i32 {
|
||||
pub fn rd(&self) -> i32 {
|
||||
i32::from(self.rd)
|
||||
}
|
||||
|
||||
pub fn r1_u8(&self) -> u8 {
|
||||
self.r1
|
||||
}
|
||||
|
||||
pub fn r1(&self) -> i32 {
|
||||
i32::from(self.r1)
|
||||
}
|
||||
|
||||
pub fn r2(&self) -> i32 {
|
||||
self.r2_or_imm
|
||||
}
|
||||
@@ -49,26 +62,26 @@ impl InstructionData {
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1)
|
||||
(self.core_indx(), self.rd(), self.r1())
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
|
||||
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
|
||||
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
|
||||
}
|
||||
|
||||
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
|
||||
(
|
||||
self.core_indx,
|
||||
self.rd,
|
||||
self.r1,
|
||||
self.core_indx(),
|
||||
self.rd(),
|
||||
self.r1(),
|
||||
self.r2_or_imm,
|
||||
self.generic3,
|
||||
self.generic1,
|
||||
@@ -78,9 +91,9 @@ impl InstructionData {
|
||||
|
||||
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
|
||||
(
|
||||
self.core_indx,
|
||||
self.rd,
|
||||
self.r1,
|
||||
self.core_indx(),
|
||||
self.rd(),
|
||||
self.r1(),
|
||||
self.r2_or_imm,
|
||||
self.generic1,
|
||||
self.generic2,
|
||||
@@ -100,7 +113,7 @@ impl InstructionData {
|
||||
}
|
||||
|
||||
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
|
||||
(self.core_indx, self.r2_or_imm)
|
||||
(self.core_indx(), self.r2_or_imm)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +229,18 @@ impl InstructionDataBuilder {
|
||||
common_getter_setter![imm_group];
|
||||
common_getter_setter![imm_core];
|
||||
|
||||
pub fn set_core_indx_u16(&mut self, val: u16) -> &mut Self {
|
||||
self.set_core_indx(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn set_rd_u8(&mut self, val: u8) -> &mut Self {
|
||||
self.set_rd(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn set_r1_u8(&mut self, val: u8) -> &mut Self {
|
||||
self.set_r1(i32::from(val))
|
||||
}
|
||||
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
core_indx: Fixer::Edit(0),
|
||||
@@ -254,20 +279,16 @@ impl InstructionDataBuilder {
|
||||
|
||||
fn check_sanity(&self) {
|
||||
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
|
||||
assert!(
|
||||
!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0)
|
||||
);
|
||||
assert!(
|
||||
!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0)
|
||||
);
|
||||
assert!(!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0));
|
||||
assert!(!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0));
|
||||
}
|
||||
|
||||
pub fn build(&mut self) -> InstructionData {
|
||||
self.check_sanity();
|
||||
let inst_data = InstructionData {
|
||||
core_indx: self.get_core_indx(),
|
||||
rd: self.get_rd(),
|
||||
r1: self.get_r1(),
|
||||
core_indx: u16::try_from(self.get_core_indx()).expect("core index does not fit in u16"),
|
||||
rd: u8::try_from(self.get_rd()).expect("rd does not fit in u8"),
|
||||
r1: u8::try_from(self.get_r1()).expect("r1 does not fit in u8"),
|
||||
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
|
||||
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
|
||||
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
|
||||
@@ -281,6 +302,10 @@ impl InstructionDataBuilder {
|
||||
self.set_rd(rd).set_r1(r1).set_r2(r2)
|
||||
}
|
||||
|
||||
pub fn set_rdr1r2_u8(&mut self, rd: u8, r1: u8, r2: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1).set_r2(r2)
|
||||
}
|
||||
|
||||
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
|
||||
self.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value)
|
||||
@@ -290,14 +315,26 @@ impl InstructionDataBuilder {
|
||||
self.set_rd(rd).set_r1(r1).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdr1imm_u8(&mut self, rd: u8, r1: u8, imm: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
|
||||
self.set_rd(rd).set_r1(r1)
|
||||
}
|
||||
|
||||
pub fn set_rdr1_u8(&mut self, rd: u8, r1: u8) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_r1_u8(r1)
|
||||
}
|
||||
|
||||
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
|
||||
self.set_rd(rd).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_rdimm_u8(&mut self, rd: u8, imm: i32) -> &mut Self {
|
||||
self.set_rd_u8(rd).set_imm(imm)
|
||||
}
|
||||
|
||||
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
|
||||
self.set_ibiw(ibiw).set_obiw(obiw)
|
||||
}
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
use crate::{
|
||||
cpu::{CPU, crossbar}, instruction_set::{
|
||||
cpu::{CPU, crossbar},
|
||||
instruction_set::{
|
||||
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
|
||||
helper::add_all,
|
||||
}, memory_manager::{
|
||||
},
|
||||
memory_manager::{
|
||||
MemoryStorable,
|
||||
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
||||
}, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
|
||||
},
|
||||
tracing::TRACER,
|
||||
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
|
||||
};
|
||||
use aligned_vec::{AVec, ConstAlign};
|
||||
use anyhow::{Context, Result, ensure};
|
||||
use rayon::prelude::*;
|
||||
|
||||
use paste::paste;
|
||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
|
||||
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
|
||||
use std::{collections::HashSet, sync::LazyLock};
|
||||
|
||||
macro_rules! add_name {
|
||||
@@ -30,7 +35,7 @@ macro_rules! add_name_simd {
|
||||
};
|
||||
}
|
||||
|
||||
static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
pub static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
|
||||
let mut hash = HashMap::new();
|
||||
add_name!(hash, sldi);
|
||||
add_name!(hash, sld);
|
||||
@@ -76,8 +81,8 @@ pub fn functor_to_name(functor: usize) -> &'static str {
|
||||
///////////////////////////////////////////////////////////////
|
||||
/////////////////Scalar/register Instructions//////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
{
|
||||
#[inline(never)]
|
||||
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sldi(cores, data);
|
||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -86,6 +91,7 @@ pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sld(cores, data);
|
||||
let (core_indx, rd, r1) = data.get_core_rd_r1();
|
||||
@@ -100,6 +106,7 @@ pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_sadd(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -110,6 +117,7 @@ pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_ssub(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -120,6 +128,7 @@ pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_smul(cores, data);
|
||||
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
|
||||
@@ -130,6 +139,7 @@ pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_saddi(cores, data);
|
||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||
@@ -139,6 +149,7 @@ pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_smuli(cores, data);
|
||||
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
|
||||
@@ -213,14 +224,17 @@ pub fn is_setbw(functor: InstructionType) -> bool {
|
||||
functor as usize == setbw as *const () as usize
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn mvm_impl_internal<F, M, T>(
|
||||
cores: &mut CPU,
|
||||
data: InstructionData,
|
||||
@@ -229,25 +243,30 @@ where
|
||||
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||
[M]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||
// Add faer::ComplexField HERE, directly bounding M for this function only
|
||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_mvm::<F,M,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
|
||||
|
||||
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
||||
let group: usize = group.try_into().context("group can not be negative")?;
|
||||
|
||||
let core = cores.core(core_indx);
|
||||
let r1_val = core.register(r1);
|
||||
let rd_val = core.register(rd);
|
||||
|
||||
let (memory, crossbars) = core.get_memory_crossbar();
|
||||
let crossbar = crossbars.get_mut(group).unwrap();
|
||||
let crossbar_stored_bytes = crossbar.stored_bytes();
|
||||
let crossbar_byte_width = crossbar.width();
|
||||
//Fix this
|
||||
|
||||
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
|
||||
ensure!(
|
||||
crossbar_byte_width & size_of::<M>() == 0,
|
||||
crossbar_byte_width % size_of::<M>() == 0,
|
||||
"M not divisor of the crosbbar size"
|
||||
);
|
||||
|
||||
let crossbar_height = crossbar.height();
|
||||
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
|
||||
|
||||
@@ -257,19 +276,29 @@ where
|
||||
let load = loads[0];
|
||||
let vec: Cow<[M]> = load.up();
|
||||
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
|
||||
let mut res = Vec::with_capacity(crossbar_elem_width);
|
||||
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
|
||||
partial.resize(vec.len(), M::from_f32(0.0));
|
||||
|
||||
for x in 0..crossbar_elem_width {
|
||||
partial[0] = vec[0] * matrix[x];
|
||||
for y in 1..crossbar_height {
|
||||
partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
|
||||
}
|
||||
// --- FAER IMPLEMENTATION ---
|
||||
|
||||
// 1. Explicitly create a Matrix Reference (MatRef)
|
||||
let matrix_view = faer::mat::MatRef::from_row_major_slice(
|
||||
matrix.as_ref(),
|
||||
crossbar_height,
|
||||
crossbar_elem_width,
|
||||
);
|
||||
|
||||
// 2. Explicitly create a Column Vector Reference (ColRef)
|
||||
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
|
||||
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
|
||||
|
||||
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
|
||||
|
||||
// 4. Convert back to standard Rust Vec
|
||||
// try_as_slice() returns an Option<&[M]>.
|
||||
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
|
||||
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
|
||||
|
||||
// --- END FAER ---
|
||||
|
||||
let mut acc = add_all(partial.as_slice());
|
||||
res.push(acc);
|
||||
}
|
||||
if relu != 0 {
|
||||
res.iter_mut().for_each(|x| {
|
||||
if *x < M::from_f32(0.0) {
|
||||
@@ -277,16 +306,20 @@ where
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
ensure!(
|
||||
res.len() == crossbar_elem_width,
|
||||
"mvm generate a vector bigger thant it's requested elements"
|
||||
"mvm generate a vector bigger thant it's requested elements"
|
||||
);
|
||||
|
||||
let res_up: Cow<[T]> = res.as_slice().up();
|
||||
core.execute_store(rd_val, res_up.as_ref());
|
||||
TRACER.lock().unwrap().post_mvm::<F,M,T>(cores, data);
|
||||
|
||||
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
|
||||
@@ -307,17 +340,19 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vvadd::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vvadd::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -345,21 +380,23 @@ where
|
||||
);
|
||||
let res_up: Cow<[T]> = res.as_slice().up();
|
||||
core.execute_store(rd_val, res_up.as_ref());
|
||||
TRACER.lock().unwrap().post_vvadd::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().post_vvadd::<F, T>(cores, data);
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vvsub::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vvsub::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -394,13 +431,14 @@ pub fn vvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vvmul::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vvmul::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -430,17 +468,19 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vvdmul::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vvdmul::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -466,17 +506,19 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vvmax::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vvmax::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -503,29 +545,33 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!(
|
||||
"Shift left on floating point what does it means? who has generated this instruction???"
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!(
|
||||
"Shift right on floating point what does it means? who has generated this instruction???"
|
||||
);
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vavg::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vavg::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -533,7 +579,10 @@ where
|
||||
let r2_val = r2;
|
||||
ensure!(r2_val == 1, "Stride different than 1 not supported");
|
||||
let rd_val = core.register(rd);
|
||||
ensure!(offset_select == 1, "Offset select cannot be different from 1");
|
||||
ensure!(
|
||||
offset_select == 1,
|
||||
"Offset select cannot be different from 1"
|
||||
);
|
||||
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
|
||||
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
|
||||
let load1 = loads[0];
|
||||
@@ -545,17 +594,19 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vrelu::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vrelu::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -575,17 +626,19 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vtanh::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vtanh::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -603,17 +656,19 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vsigm::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vsigm::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -629,17 +684,22 @@ where
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
panic!("You are calling a placeholder, the real call is the generic version");
|
||||
}
|
||||
|
||||
pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
#[inline(never)]
|
||||
pub(super) fn vsoftmax_impl<F, T>(
|
||||
cores: &mut CPU,
|
||||
data: InstructionData,
|
||||
) -> Result<InstructionStatus>
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
TRACER.lock().unwrap().pre_vsoftmax::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().pre_vsoftmax::<F, T>(cores, data);
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let core = cores.core(core_indx);
|
||||
@@ -656,27 +716,29 @@ where
|
||||
.reduce(|a, b| if a > b { a } else { b })
|
||||
.unwrap();
|
||||
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
|
||||
let sum = exp_values
|
||||
.iter()
|
||||
.copied()
|
||||
.reduce(|a, b| a + b)
|
||||
.unwrap();
|
||||
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
|
||||
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap();
|
||||
ensure!(
|
||||
sum > 0.0.into(),
|
||||
"vsoftmax normalization sum must be positive"
|
||||
);
|
||||
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
|
||||
let res_up: Cow<[T]> = res.as_slice().up();
|
||||
core.execute_store(rd_val, res_up.as_ref());
|
||||
TRACER.lock().unwrap().post_vsoftmax::<F,T>(cores, data);
|
||||
TRACER.lock().unwrap().post_vsoftmax::<F, T>(cores, data);
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
todo!()
|
||||
}
|
||||
@@ -684,6 +746,7 @@ pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
///////////////////////////////////////////////////////////////
|
||||
///Communication/synchronization Instructions/////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
#[inline(never)]
|
||||
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_ld(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -700,6 +763,7 @@ pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_st(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -716,6 +780,7 @@ pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_lldi(cores, data);
|
||||
let (core, rd, imm) = data.get_core_rd_imm();
|
||||
@@ -732,6 +797,7 @@ pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_lmv(cores, data);
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
@@ -748,20 +814,32 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
|
||||
Ok(InstructionStatus::Completed)
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn isa_send(functor : usize) -> bool{
|
||||
(send as *const () as usize) == functor
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_send(cores, data);
|
||||
Ok(InstructionStatus::Sending(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn isa_recv(functor : usize) -> bool{
|
||||
(recv as *const () as usize) == functor
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
TRACER.lock().unwrap().pre_recv(cores, data);
|
||||
Ok(InstructionStatus::Reciving(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Waiting(data))
|
||||
}
|
||||
|
||||
#[inline(never)]
|
||||
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
|
||||
Ok(InstructionStatus::Sync(data))
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ pub mod helper;
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Instruction {
|
||||
pub data: InstructionData,
|
||||
functor: InstructionType,
|
||||
pub functor: InstructionType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
|
||||
@@ -567,7 +567,7 @@ fn json_to_send(
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_imm_core(core)
|
||||
.set_imm_core(core + 1)
|
||||
.set_imm_len(size)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
@@ -588,7 +588,7 @@ fn json_to_recv(
|
||||
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
|
||||
inst_data_builder
|
||||
.set_rd(rd)
|
||||
.set_imm_core(core)
|
||||
.set_imm_core(core + 1)
|
||||
.set_imm_len(size)
|
||||
.set_offset_select(offset_select)
|
||||
.set_offset_value(offset_value);
|
||||
|
||||
+15
-28
@@ -1,45 +1,32 @@
|
||||
use core::panic;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde_json::{Map, Value};
|
||||
use serde_json::Value;
|
||||
use std::{fs::File, io::BufReader};
|
||||
|
||||
use crate::{
|
||||
CoreInstructionsBuilder, Executable,
|
||||
cpu::{CPU, crossbar::{self, Crossbar}},
|
||||
instruction_set::{
|
||||
InstructionsBuilder,
|
||||
instruction_data::{self, InstructionData, InstructionDataBuilder},
|
||||
},
|
||||
json_to_instruction::{self, json_isa},
|
||||
memory_manager::type_traits::TryToUsize,
|
||||
cpu::{CPU, crossbar::Crossbar},
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
|
||||
json_to_instruction::json_isa,
|
||||
};
|
||||
|
||||
|
||||
pub fn json_to_executor<'a>(
|
||||
pub fn json_to_executor<'a, 'b>(
|
||||
config: Value,
|
||||
mut cores: impl Iterator<Item = &'a Value>,
|
||||
crossbars : Vec<Vec<&'a Crossbar>>
|
||||
cores: &'b mut Vec<BufReader<File>>,
|
||||
crossbars: Vec<Vec<&'a Crossbar>>,
|
||||
) -> Executable<'a> {
|
||||
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
|
||||
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
|
||||
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
|
||||
|
||||
let mut cpu = CPU::new(core_cnt, crossbars);
|
||||
let cpu = CPU::new(core_cnt, crossbars);
|
||||
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
|
||||
cores.next();
|
||||
for core_indx in 1..=core_cnt {
|
||||
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
|
||||
let core_indx = external_core_indx as i32 + 1;
|
||||
let mut insts_builder = InstructionsBuilder::new();
|
||||
let mut inst_data_builder = InstructionDataBuilder::new();
|
||||
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
|
||||
let json_core = cores
|
||||
.next()
|
||||
.unwrap_or_else(|| panic!("cores files less than {}", core_indx ));
|
||||
let json_core: Value = serde_json::from_reader(json_core_reader)
|
||||
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
|
||||
let json_core_insts = json_core
|
||||
.as_array()
|
||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", core_indx));
|
||||
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
|
||||
for json_inst in json_core_insts {
|
||||
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
|
||||
}
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
mod json_isa;
|
||||
pub(crate) mod json_isa;
|
||||
pub mod json_to_executor;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::cmp::min;
|
||||
use std::fmt::Debug;
|
||||
|
||||
use anyhow::{Context, Result, bail, ensure};
|
||||
@@ -86,7 +87,7 @@ where {
|
||||
size,
|
||||
};
|
||||
if self.memory.len() < address + size {
|
||||
self.memory.resize((address + size) * 2, 0);
|
||||
self.memory.resize(min((address + size) * 2, u32::MAX as usize), 0);
|
||||
}
|
||||
self.load_requests.push(load_request);
|
||||
Ok(self)
|
||||
|
||||
@@ -55,15 +55,23 @@ pub trait HasSigm {
|
||||
|
||||
impl HasSigm for f32 {
|
||||
fn sigm(self) -> Self {
|
||||
let ex = self.exp();
|
||||
ex / (1.0 + ex)
|
||||
if self >= 0.0 {
|
||||
1.0 / (1.0 + (-self).exp())
|
||||
} else {
|
||||
let ex = self.exp();
|
||||
ex / (1.0 + ex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HasSigm for f64 {
|
||||
fn sigm(self) -> Self {
|
||||
let ex = self.exp();
|
||||
ex / (1.0 + ex)
|
||||
if self >= 0.0 {
|
||||
1.0 / (1.0 + (-self).exp())
|
||||
} else {
|
||||
let ex = self.exp();
|
||||
ex / (1.0 + ex)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
#![allow(unused)]
|
||||
|
||||
use std::time::{Duration, SystemTime};
|
||||
use anyhow::{Result, bail};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
cpu::CPU,
|
||||
instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name},
|
||||
instruction_set::{
|
||||
Instruction, InstructionStatus, Instructions,
|
||||
isa::{NAMES, functor_to_name, isa_recv, isa_send},
|
||||
},
|
||||
memory_manager::type_traits::TryToUsize,
|
||||
send_recv::{SendRecv, handle_send_recv},
|
||||
tracing::TRACER,
|
||||
};
|
||||
pub mod binary_to_instruction;
|
||||
pub mod cpu;
|
||||
pub mod instruction_set;
|
||||
pub mod json_to_instruction;
|
||||
@@ -80,6 +88,11 @@ pub struct Executable<'a> {
|
||||
send_recv: SendRecv,
|
||||
}
|
||||
|
||||
struct DeadlockInfo {
|
||||
cycle: String,
|
||||
states: String,
|
||||
}
|
||||
|
||||
fn print_status(core_instructions: &[CoreInstructions]) {
|
||||
let mut tot_instructions = 0;
|
||||
let mut progress = 0;
|
||||
@@ -111,7 +124,7 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute<'b>(&'b mut self)
|
||||
pub fn execute<'b>(&'b mut self) -> Result<()>
|
||||
where
|
||||
'a: 'b,
|
||||
{
|
||||
@@ -144,8 +157,15 @@ impl<'a> Executable<'a> {
|
||||
cpu_progressed = 0;
|
||||
*program_counter += 1;
|
||||
}
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
|
||||
print_status(&cores_instructions);
|
||||
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
|
||||
print_status(cores_instructions);
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
now = SystemTime::now();
|
||||
}
|
||||
}
|
||||
@@ -169,6 +189,24 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
print_status(cores_instructions);
|
||||
|
||||
if let Some(deadlock) = detect_deadlock(cores_instructions) {
|
||||
bail!(
|
||||
"Deadlock cycle detected: {} [{}]",
|
||||
deadlock.cycle,
|
||||
deadlock.states
|
||||
);
|
||||
}
|
||||
if cores_instructions
|
||||
.iter()
|
||||
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
|
||||
{
|
||||
bail!("Execution stalled with unfinished instructions");
|
||||
}
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
TRACER.lock().unwrap().report();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn cpu(&self) -> &CPU<'a> {
|
||||
@@ -190,6 +228,125 @@ impl<'a> Executable<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum CoreState {
|
||||
SendingTo(i32, i32),
|
||||
ReceivingFrom(i32, i32),
|
||||
Working,
|
||||
Halted,
|
||||
}
|
||||
|
||||
let mut states = HashMap::new();
|
||||
|
||||
for core_inst in cores_instructions.iter() {
|
||||
if core_inst.program_counter >= core_inst.instructions.len() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Instruction { data, functor } = core_inst.instructions[core_inst.program_counter];
|
||||
let functor_address = functor as usize;
|
||||
|
||||
let (this_core, target_core) = data.get_core_immcore();
|
||||
|
||||
if isa_recv(functor_address) {
|
||||
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
|
||||
} else if isa_send(functor_address) {
|
||||
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
|
||||
} else {
|
||||
states.insert(this_core, CoreState::Working);
|
||||
}
|
||||
}
|
||||
|
||||
let mut wait_for = HashMap::new();
|
||||
|
||||
for (&core_id, state) in states.iter() {
|
||||
match state {
|
||||
CoreState::SendingTo(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::ReceivingFrom(target_core, size) => {
|
||||
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
|
||||
if target_state != &CoreState::SendingTo(core_id, *size) {
|
||||
wait_for.insert(core_id, *target_core);
|
||||
}
|
||||
}
|
||||
CoreState::Working | CoreState::Halted => {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut visited = HashSet::new();
|
||||
|
||||
for &start_core in wait_for.keys() {
|
||||
if visited.contains(&start_core) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut path = Vec::new();
|
||||
let mut current_core = start_core;
|
||||
let mut in_path = HashSet::new();
|
||||
|
||||
while let Some(&waiting_for) = wait_for.get(¤t_core) {
|
||||
path.push(current_core);
|
||||
in_path.insert(current_core);
|
||||
visited.insert(current_core);
|
||||
|
||||
// Found a closed loop!
|
||||
if in_path.contains(&waiting_for) {
|
||||
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||
let cycle = &path[cycle_start..];
|
||||
let format_core = |core: &i32| (core - 1).to_string();
|
||||
|
||||
let cycle_str = cycle
|
||||
.iter()
|
||||
.map(format_core)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
let cycle = cycle
|
||||
.iter()
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core - 1, size, target - 1)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core - 1),
|
||||
CoreState::Halted => format!("core {} halted", core - 1),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
|
||||
return Some(DeadlockInfo {
|
||||
cycle: cycle_msg,
|
||||
states: states_msg,
|
||||
});
|
||||
}
|
||||
|
||||
// Hit a known branch that didn't result in a cycle
|
||||
if visited.contains(&waiting_for) {
|
||||
break;
|
||||
}
|
||||
|
||||
current_core = waiting_for;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn handle_wait_sync<'a, 'b, 'c>(
|
||||
cpu: &'b mut CPU<'a>,
|
||||
core_instructions: &'c mut [CoreInstructions],
|
||||
|
||||
@@ -58,6 +58,20 @@ where 'a : 'b
|
||||
&& sender.internal_core == receiver.external_core
|
||||
&& receiver.internal_core == sender.external_core
|
||||
{
|
||||
{
|
||||
let sender = &mut core_instructions[sender.internal_core];
|
||||
let pc = sender.program_counter;
|
||||
let inst = sender.instructions.get(pc).unwrap();
|
||||
let data = inst.data;
|
||||
TRACER.lock().unwrap().pre_send(cpu, data);
|
||||
}
|
||||
{
|
||||
let recv = &mut core_instructions[receiver.internal_core];
|
||||
let pc = recv.program_counter;
|
||||
let inst = recv.instructions.get(pc).unwrap();
|
||||
let data = inst.data;
|
||||
TRACER.lock().unwrap().pre_recv(cpu, data);
|
||||
}
|
||||
let [sender_core, reciver_core] =
|
||||
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
|
||||
let memory = sender_core
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::{
|
||||
};
|
||||
use std::io::Write;
|
||||
|
||||
#[cfg(not(feature = "tracing"))]
|
||||
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||
impl Trace {
|
||||
///////////////////////////////////////////////////////////////
|
||||
/////////////////Scalar/register Instructions//////////////////
|
||||
|
||||
@@ -1,52 +1,32 @@
|
||||
mod tracing_isa;
|
||||
mod disable;
|
||||
mod pretty_print;
|
||||
use std::{fs::File, path::{ PathBuf}};
|
||||
#[cfg(feature = "profile_time")]
|
||||
mod profile;
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
use profile::Trace;
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
mod trace;
|
||||
#[cfg(feature = "tracing")]
|
||||
use trace::Trace;
|
||||
|
||||
use crate::Executable;
|
||||
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{LazyLock, Mutex};
|
||||
|
||||
|
||||
use crate::Executable;
|
||||
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||
pub struct Trace {}
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
pub struct Trace {
|
||||
out_files : Vec<File>
|
||||
}
|
||||
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
|
||||
impl Trace {
|
||||
fn new() -> Self {
|
||||
Self { out_files : Vec::new()}
|
||||
Self {}
|
||||
}
|
||||
|
||||
|
||||
pub fn init(&mut self, num_core : usize , mut path : PathBuf) {
|
||||
path.pop();
|
||||
for i in 0..num_core {
|
||||
path.push(format!("TraceCore{}", i));
|
||||
let file = File::create(&path).expect("Can not create file");
|
||||
self.out_files.push(file);
|
||||
path.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "tracing"))]
|
||||
pub struct Trace {
|
||||
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(not(feature = "tracing"))]
|
||||
impl Trace {
|
||||
fn new() -> Self {
|
||||
Self { }
|
||||
}
|
||||
|
||||
|
||||
pub fn init(&mut self, num_core : usize, path : PathBuf ) {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| { Trace::new().into()});
|
||||
|
||||
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into());
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
use std::{collections::HashMap, path::PathBuf, time::Instant};
|
||||
|
||||
use crate::tracing::profile::profile_analysis::{
|
||||
analyze_timings, generate_interactive_report, print_textual_report,
|
||||
};
|
||||
|
||||
pub mod profile_analysis;
|
||||
pub mod profile_isa;
|
||||
|
||||
pub struct Trace {
|
||||
instruction_times: HashMap<String, Vec<(u128,u128)>>,
|
||||
core_start_time: HashMap<usize, Option<Instant>>,
|
||||
start_time: Instant,
|
||||
}
|
||||
|
||||
impl Trace {
|
||||
pub fn new() -> Self {
|
||||
let mut instruction_times = HashMap::new();
|
||||
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
|
||||
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
|
||||
Self {
|
||||
instruction_times,
|
||||
core_start_time: HashMap::new(),
|
||||
start_time: Instant::now()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(&mut self, num_core: usize, path: PathBuf) {
|
||||
for i in 0..num_core {
|
||||
self.core_start_time.insert(i, None);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn report(&self) {
|
||||
let res = analyze_timings(&self.instruction_times);
|
||||
print_textual_report(&res);
|
||||
generate_interactive_report(
|
||||
&self.instruction_times,
|
||||
&["mvmul", "recv"],
|
||||
"/tmp/report.html",
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
|
||||
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct InstructionStats {
|
||||
pub name: String,
|
||||
pub count: usize,
|
||||
pub total_time: u128,
|
||||
pub min: f64,
|
||||
pub max: f64,
|
||||
pub mean: f64,
|
||||
pub median: f64,
|
||||
pub std_dev: f64,
|
||||
pub cv: f64,
|
||||
pub p95: f64,
|
||||
pub p99: f64,
|
||||
pub skewness: f64,
|
||||
pub kurtosis: f64,
|
||||
}
|
||||
|
||||
fn format_time(ns: f64) -> String {
|
||||
if ns.is_nan() {
|
||||
return "NaN".to_string();
|
||||
}
|
||||
|
||||
if ns >= 1_000_000_000.0 {
|
||||
format!("{:.2} s", ns / 1_000_000_000.0)
|
||||
} else if ns >= 1_000_000.0 {
|
||||
format!("{:.2} ms", ns / 1_000_000.0)
|
||||
} else if ns >= 1_000.0 {
|
||||
format!("{:.2} µs", ns / 1_000.0)
|
||||
} else {
|
||||
format!("{:.2} ns", ns)
|
||||
}
|
||||
}
|
||||
|
||||
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
|
||||
let n = times.len() as f64;
|
||||
|
||||
if n < 4.0 || std_dev == 0.0 {
|
||||
return (f64::NAN, f64::NAN);
|
||||
}
|
||||
|
||||
let mut sum_m3 = 0.0;
|
||||
let mut sum_m4 = 0.0;
|
||||
|
||||
for &x in times {
|
||||
let deviation = x - mean;
|
||||
sum_m3 += deviation.powi(3);
|
||||
sum_m4 += deviation.powi(4);
|
||||
}
|
||||
|
||||
let m3 = sum_m3 / n;
|
||||
let m4 = sum_m4 / n;
|
||||
|
||||
let skewness = m3 / std_dev.powi(3);
|
||||
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
|
||||
|
||||
(skewness, kurtosis)
|
||||
}
|
||||
|
||||
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for (instruction, times) in timings {
|
||||
let count = times.len();
|
||||
if count == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract ONLY the duration (the second element of the tuple) for stats
|
||||
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
|
||||
let total_time: u128 = durations.iter().sum();
|
||||
|
||||
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
|
||||
let mut data = Data::new(f64_times.clone());
|
||||
|
||||
let mean = data.mean().unwrap_or(0.0);
|
||||
let std_dev = data.std_dev().unwrap_or(0.0);
|
||||
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
|
||||
|
||||
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
|
||||
|
||||
results.push(InstructionStats {
|
||||
name: instruction.clone(),
|
||||
count,
|
||||
total_time,
|
||||
min: data.min(),
|
||||
max: data.max(),
|
||||
mean,
|
||||
median: data.median(),
|
||||
std_dev,
|
||||
cv,
|
||||
p95: data.percentile(95),
|
||||
p99: data.percentile(99),
|
||||
skewness,
|
||||
kurtosis,
|
||||
});
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
|
||||
results
|
||||
}
|
||||
|
||||
pub fn print_textual_report(stats: &[InstructionStats]) {
|
||||
let mut table = Table::new();
|
||||
table
|
||||
.load_preset(UTF8_FULL)
|
||||
.apply_modifier(UTF8_ROUND_CORNERS)
|
||||
.set_header(vec![
|
||||
"Instruction",
|
||||
"Count",
|
||||
"Total Time",
|
||||
"Mean",
|
||||
"Median",
|
||||
"Min",
|
||||
"Max",
|
||||
"P95",
|
||||
"P99",
|
||||
"StdDev",
|
||||
"CV",
|
||||
"Skewness",
|
||||
"Kurtosis",
|
||||
]);
|
||||
|
||||
for stat in stats {
|
||||
table.add_row(vec![
|
||||
Cell::new(&stat.name),
|
||||
Cell::new(stat.count.to_string()),
|
||||
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
|
||||
Cell::new(format_time(stat.mean)),
|
||||
Cell::new(format_time(stat.median)),
|
||||
Cell::new(format_time(stat.min)),
|
||||
Cell::new(format_time(stat.max)),
|
||||
Cell::new(format_time(stat.p95)),
|
||||
Cell::new(format_time(stat.p99)),
|
||||
Cell::new(format_time(stat.std_dev)),
|
||||
Cell::new(format!("{:.3}", stat.cv)),
|
||||
Cell::new(format!("{:.2}", stat.skewness)),
|
||||
Cell::new(format!("{:.2}", stat.kurtosis)),
|
||||
]);
|
||||
}
|
||||
|
||||
println!("{table}");
|
||||
}
|
||||
|
||||
|
||||
pub fn generate_interactive_report(
|
||||
timings: &HashMap<String, Vec<(u128, u128)>>,
|
||||
instructions_to_plot: &[&str], // <-- NEW: Only plot these
|
||||
file_path: &str,
|
||||
) {
|
||||
|
||||
use plotly::common::{Mode, Marker, Line};
|
||||
use plotly::layout::{Axis, Layout};
|
||||
use plotly::{Plot, Scatter};
|
||||
use std::collections::HashMap;
|
||||
let mut plot = Plot::new();
|
||||
|
||||
for &instruction_name in instructions_to_plot {
|
||||
// Only proceed if the instruction exists in our timings map
|
||||
if let Some(times) = timings.get(instruction_name) {
|
||||
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
|
||||
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
|
||||
|
||||
let text_array: Vec<String> = times.iter()
|
||||
.map(|&(_, dur)| format_time(dur as f64))
|
||||
.collect();
|
||||
|
||||
let trace = Scatter::new(x_axis, y_axis)
|
||||
.name(instruction_name)
|
||||
.mode(Mode::LinesMarkers)
|
||||
.marker(Marker::new().size(4).opacity(0.6))
|
||||
.line(Line::new().width(1.0))
|
||||
.text_array(text_array)
|
||||
.hover_info(plotly::common::HoverInfo::All);
|
||||
|
||||
plot.add_trace(trace);
|
||||
}
|
||||
}
|
||||
|
||||
let layout = Layout::new()
|
||||
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
|
||||
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
|
||||
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
|
||||
|
||||
plot.set_layout(layout);
|
||||
plot.write_html(file_path);
|
||||
println!("🌐 Interactive timeline saved to {}", file_path);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,364 @@
|
||||
use crate::{
|
||||
cpu::CPU,
|
||||
instruction_set::instruction_data::InstructionData,
|
||||
memory_manager::{
|
||||
MemoryStorable,
|
||||
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
|
||||
},
|
||||
tracing::Trace,
|
||||
utility::{add_offset_r1, add_offset_rd},
|
||||
};
|
||||
use std::io::Write;
|
||||
use std::time::Instant;
|
||||
|
||||
#[cfg(feature = "profile_time")]
|
||||
impl Trace {
|
||||
///////////////////////////////////////////////////////////////
|
||||
/////////////////Scalar/register Instructions//////////////////
|
||||
///////////////////////////////////////////////////////////////
|
||||
|
||||
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||
let core_indx = core_indx as usize;
|
||||
if self.core_start_time.get(&core_indx).unwrap().is_none() {
|
||||
self.core_start_time.insert(core_indx, Some(Instant::now()));
|
||||
}
|
||||
}
|
||||
|
||||
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
|
||||
let (core_indx, rd, imm) = data.get_core_rd_imm();
|
||||
let core_indx = core_indx as usize;
|
||||
let Self {
|
||||
instruction_times,
|
||||
core_start_time,
|
||||
start_time,
|
||||
} = self;
|
||||
let now = Instant::now();
|
||||
instruction_times
|
||||
.get_mut(name)
|
||||
.unwrap()
|
||||
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
|
||||
self.core_start_time.insert(core_indx, None);
|
||||
}
|
||||
|
||||
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "sldi");
|
||||
}
|
||||
|
||||
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "sld");
|
||||
}
|
||||
|
||||
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "sadd");
|
||||
}
|
||||
|
||||
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "ssub");
|
||||
}
|
||||
|
||||
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "smul");
|
||||
}
|
||||
|
||||
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "saddi");
|
||||
}
|
||||
|
||||
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "smuli");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
///////////////////Matrix/vector Instructions////////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "setbw");
|
||||
}
|
||||
|
||||
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||
[M]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T> + UpcastSlice<M>,
|
||||
[M]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "mvmul");
|
||||
}
|
||||
|
||||
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "vvadd");
|
||||
}
|
||||
|
||||
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "vvsub");
|
||||
}
|
||||
|
||||
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "vvmul");
|
||||
}
|
||||
|
||||
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "vvdmul");
|
||||
}
|
||||
|
||||
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "vvmax");
|
||||
}
|
||||
|
||||
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
self.post_impl(cores, data, "vavg");
|
||||
}
|
||||
|
||||
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.post_impl(cores, data, "vrelu");
|
||||
}
|
||||
|
||||
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.post_impl(cores, data, "vtanh");
|
||||
}
|
||||
|
||||
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.post_impl(cores, data, "vsigm");
|
||||
}
|
||||
|
||||
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
|
||||
where
|
||||
[F]: UpcastSlice<T>,
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
|
||||
{
|
||||
self.post_impl(cores, data, "vsoftmax");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////
|
||||
/////Communication/synchronization Instructions/////////////////
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "ld");
|
||||
}
|
||||
|
||||
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "st");
|
||||
}
|
||||
|
||||
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "lldi");
|
||||
}
|
||||
|
||||
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "lmv");
|
||||
}
|
||||
|
||||
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "send");
|
||||
}
|
||||
|
||||
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.pre_impl(cores, data);
|
||||
}
|
||||
|
||||
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
|
||||
self.post_impl(cores, data, "recv");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
use std::{fs::File, path::PathBuf};
|
||||
|
||||
pub mod pretty_print;
|
||||
pub mod tracing_isa;
|
||||
|
||||
pub struct Trace {
|
||||
out_files: Vec<File>,
|
||||
}
|
||||
|
||||
|
||||
impl Trace {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
out_files: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
|
||||
path.pop();
|
||||
for i in 0..num_core {
|
||||
path.push(format!("TraceCore{}", i));
|
||||
let file = File::create(&path).expect("Can not create file");
|
||||
self.out_files.push(file);
|
||||
path.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+1
-10
@@ -1,4 +1,4 @@
|
||||
use crate::tracing::pretty_print;
|
||||
use crate::{tracing::trace::pretty_print, utility::add_offset_r2};
|
||||
use std::fs::File;
|
||||
|
||||
use crate::{
|
||||
@@ -13,7 +13,6 @@ use crate::{
|
||||
};
|
||||
use std::io::Write;
|
||||
|
||||
#[cfg(feature = "tracing")]
|
||||
impl Trace {
|
||||
///////////////////////////////////////////////////////////////
|
||||
/////////////////Scalar/register Instructions//////////////////
|
||||
@@ -284,7 +283,6 @@ impl Trace {
|
||||
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
use crate::tracing::pretty_print;
|
||||
|
||||
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
|
||||
let file: &mut File = self
|
||||
@@ -358,8 +356,6 @@ impl Trace {
|
||||
T: UpcastDestTraits<T> + MemoryStorable,
|
||||
F: UpcastDestTraits<F> + MemoryStorable,
|
||||
{
|
||||
use crate::{tracing::pretty_print, utility::add_offset_r2};
|
||||
|
||||
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let file: &mut File = self
|
||||
@@ -990,8 +986,6 @@ impl Trace {
|
||||
/////////////////////////////////////////////////////////////////
|
||||
|
||||
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||
use crate::tracing::pretty_print;
|
||||
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let file: &mut File = self
|
||||
@@ -1044,8 +1038,6 @@ impl Trace {
|
||||
}
|
||||
|
||||
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||
use crate::tracing::pretty_print;
|
||||
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
let file: &mut File = self
|
||||
@@ -1138,7 +1130,6 @@ impl Trace {
|
||||
}
|
||||
|
||||
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
|
||||
use crate::tracing::pretty_print;
|
||||
|
||||
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
|
||||
data.get_core_rd_r1_r2_immlen_offset();
|
||||
@@ -1,7 +1,45 @@
|
||||
use anyhow::{Result,Context};
|
||||
use std::{fmt::Debug, mem::transmute};
|
||||
|
||||
use crate::memory_manager::type_traits::TryToUsize;
|
||||
|
||||
pub trait AddressArg {
|
||||
fn to_address_usize(self) -> Result<usize>;
|
||||
}
|
||||
|
||||
impl AddressArg for usize {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for u32 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
Ok(self as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for u64 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
usize::try_from(self).context("address does not fit in usize")
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for i32 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
Ok(self as u32 as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for i64 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
usize::try_from(self).context("address can not be negative")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn address_to_usize(address: i32) -> usize {
|
||||
address as u32 as usize
|
||||
}
|
||||
|
||||
fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i32) -> usize{
|
||||
assert!(offset_select == 1 || offset_select == 2 || offset_select == 4 || offset_value == 0, "offset_select not a bit field");
|
||||
@@ -14,21 +52,21 @@ fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i
|
||||
}
|
||||
|
||||
|
||||
pub fn add_offset_rd(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
|
||||
pub fn add_offset_rd(address: i32, offset_select : i32, offset_value : i32) -> usize
|
||||
{
|
||||
let address = address.try_into().expect("address can not be negative");
|
||||
let address = address_to_usize(address);
|
||||
add_offset_impl(address, offset_select, offset_value, 4)
|
||||
}
|
||||
|
||||
pub fn add_offset_r1(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
|
||||
pub fn add_offset_r1(address: i32, offset_select : i32, offset_value : i32) -> usize
|
||||
{
|
||||
let address = address.try_into().expect("address can not be negative");
|
||||
let address = address_to_usize(address);
|
||||
add_offset_impl(address, offset_select, offset_value, 1)
|
||||
}
|
||||
|
||||
pub fn add_offset_r2(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
|
||||
pub fn add_offset_r2(address: i32, offset_select : i32, offset_value : i32) -> usize
|
||||
{
|
||||
let address = address.try_into().expect("address can not be negative");
|
||||
let address = address_to_usize(address);
|
||||
add_offset_impl(address, offset_select, offset_value, 2)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
use std::path::Path;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
|
||||
fn simple_read(path: &Path) -> Vec<f32> {
|
||||
if !path.exists() {
|
||||
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
|
||||
fn mvmul_f32(err: &str)
|
||||
where
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix = simple_read(Path::new("B.txt")) ;
|
||||
|
||||
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let vector = simple_read(Path::new("A.txt"));
|
||||
let matrix = simple_read(Path::new("tests/B.txt"));
|
||||
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector = simple_read(Path::new("tests/A.txt"));
|
||||
memory.execute_store(0, &vector).unwrap();
|
||||
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
@@ -57,7 +60,7 @@ where
|
||||
.cpu_mut()
|
||||
.host()
|
||||
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
|
||||
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
|
||||
"Wrong result for {}",
|
||||
err
|
||||
);
|
||||
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
use pimcore::cpu::CPU;
|
||||
|
||||
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
|
||||
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
|
||||
}
|
||||
@@ -1,51 +1,103 @@
|
||||
use std::{fs, io::BufReader, path::Path};
|
||||
use std::{
|
||||
fs::{self, File},
|
||||
io::BufReader,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use pimcore::json_to_instruction::json_to_executor;
|
||||
use pimcore::{
|
||||
cpu::crossbar::Crossbar,
|
||||
json_to_instruction::json_to_executor,
|
||||
memory_manager::CoreMemory,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
|
||||
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
|
||||
let mut result = Vec::new();
|
||||
for entry in fs::read_dir(root)? {
|
||||
let entry = entry.context("Root not found")?;
|
||||
let path = entry.path();
|
||||
if path.is_dir() {
|
||||
let mut cores = Vec::new();
|
||||
let mut config: Option<Value> = None;
|
||||
for sub_entry in fs::read_dir(&path)
|
||||
.with_context(|| format!("File {} not readable", path.display()))?
|
||||
{
|
||||
let sub_entry =
|
||||
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
|
||||
let sub_path = sub_entry.path();
|
||||
if sub_path.is_file()
|
||||
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
|
||||
{
|
||||
let file = fs::File::open(&sub_path)
|
||||
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
|
||||
let reader = BufReader::new(file);
|
||||
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
|
||||
"Serde reader fail for subpath {}",
|
||||
sub_path.display()
|
||||
))?;
|
||||
if sub_path.file_name().unwrap() == "config.json" {
|
||||
config = Some(val);
|
||||
} else {
|
||||
cores.push(val);
|
||||
}
|
||||
}
|
||||
}
|
||||
result.push((config.unwrap(), cores));
|
||||
result.push(path);
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn core_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[5..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn crossbar_sort_key(path: &Path) -> i32 {
|
||||
let stem = path.file_stem().unwrap().to_str().unwrap();
|
||||
stem[9..].parse::<i32>().unwrap()
|
||||
}
|
||||
|
||||
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
|
||||
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
|
||||
let rows = xbar_size[0].as_i64().unwrap() as usize;
|
||||
let cols = xbar_size[1].as_i64().unwrap() as usize;
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
|
||||
owned_crossbars.push(Vec::new());
|
||||
|
||||
for core_idx in 0..core_cnt {
|
||||
let core_folder = folder.join(format!("core_{core_idx}"));
|
||||
let mut core_crossbars = Vec::new();
|
||||
if core_folder.is_dir() {
|
||||
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
|
||||
.map(|entry| entry.map(|entry| entry.path()))
|
||||
.collect::<std::io::Result<Vec<_>>>()?;
|
||||
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
|
||||
for path in paths {
|
||||
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
|
||||
continue;
|
||||
}
|
||||
let bytes = fs::read(&path)
|
||||
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
|
||||
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
|
||||
crossbar.execute_store(&bytes)?;
|
||||
core_crossbars.push(crossbar);
|
||||
}
|
||||
}
|
||||
owned_crossbars.push(core_crossbars);
|
||||
}
|
||||
Ok(owned_crossbars)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_folder_tester() {
|
||||
let examples = collect_json_from_subfolders("data").unwrap();
|
||||
for example in examples {
|
||||
let (config, cores) = example;
|
||||
json_to_executor::json_to_executor(config, cores.iter()).execute();
|
||||
let examples = collect_examples("tests/data").unwrap();
|
||||
for folder in examples {
|
||||
let config_path = folder.join("config.json");
|
||||
let config_file = File::open(&config_path).unwrap();
|
||||
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
|
||||
|
||||
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
|
||||
let mut core_paths: Vec<_> = fs::read_dir(&folder)
|
||||
.unwrap()
|
||||
.map(|entry| entry.unwrap().path())
|
||||
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
|
||||
.filter(|path| path.file_name().unwrap() != "config.json")
|
||||
.collect();
|
||||
core_paths.sort_by_cached_key(|path| core_sort_key(path));
|
||||
assert_eq!(core_paths.len(), core_cnt);
|
||||
|
||||
let mut core_readers: Vec<_> = core_paths
|
||||
.into_iter()
|
||||
.map(|path| BufReader::new(File::open(path).unwrap()))
|
||||
.collect();
|
||||
|
||||
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
|
||||
let crossbars = owned_crossbars
|
||||
.iter()
|
||||
.map(|core_crossbars| core_crossbars.iter().collect())
|
||||
.collect();
|
||||
|
||||
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
|
||||
let memory = fs::read(folder.join("memory.bin")).unwrap();
|
||||
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
|
||||
executable.execute();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
|
||||
use pimcore::{
|
||||
Executable,
|
||||
instruction_set::{
|
||||
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Function not found for the requested size") ]
|
||||
fn wrong_size_place_holder() {
|
||||
let cpu = CPU::new(0);
|
||||
let cpu = common::empty_cpu(0);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
|
||||
|
||||
|
||||
fn place_holder(inst : InstructionType) {
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(0).fix_core_indx();
|
||||
inst(&mut cpu, idata_build.build()).unwrap();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable,
|
||||
cpu::CPU,
|
||||
cpu::crossbar::Crossbar,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
|
||||
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
|
||||
};
|
||||
|
||||
/// VVADD Test
|
||||
@@ -11,7 +13,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -115,7 +117,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -219,7 +221,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -323,7 +325,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -420,7 +422,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -524,7 +526,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
9.0.into(),
|
||||
2.0.into(),
|
||||
@@ -562,6 +564,7 @@ where
|
||||
vavg,
|
||||
idata_build
|
||||
.set_rdr1r2(3, 1, 1)
|
||||
.set_offset_select(1)
|
||||
.set_imm_len(8 * size_of::<F>() as i32)
|
||||
.build(),
|
||||
);
|
||||
@@ -617,7 +620,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
(-9.0).into(),
|
||||
2.0.into(),
|
||||
@@ -717,7 +720,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -819,7 +822,7 @@ where
|
||||
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
let mut cpu = common::empty_cpu(0);
|
||||
let buff: [F; _] = [
|
||||
0.1.into(),
|
||||
0.2.into(),
|
||||
@@ -923,9 +926,6 @@ where
|
||||
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
|
||||
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
|
||||
{
|
||||
let mut cpu = CPU::new(0);
|
||||
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
|
||||
let (memory, crossbars) = cpu.host().get_memory_crossbar();
|
||||
let matrix: [M; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -944,7 +944,10 @@ where
|
||||
15.0.into(),
|
||||
16.0.into(),
|
||||
];
|
||||
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
|
||||
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
|
||||
crossbar.execute_store(&matrix).unwrap();
|
||||
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
|
||||
let (memory, _) = cpu.host().get_memory_crossbar();
|
||||
let vector: [F; _] = [
|
||||
1.0.into(),
|
||||
2.0.into(),
|
||||
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
|
||||
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
mod common;
|
||||
|
||||
use pimcore::{
|
||||
Executable, CoreInstructionsBuilder,
|
||||
cpu::CPU,
|
||||
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn ld_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -41,7 +42,7 @@ fn ld_test() {
|
||||
|
||||
#[test]
|
||||
fn st_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -76,7 +77,7 @@ fn st_test() {
|
||||
|
||||
#[test]
|
||||
fn lldi_test() {
|
||||
let cpu = CPU::new(1);
|
||||
let cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let mut inst_builder = InstructionsBuilder::new();
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
@@ -106,7 +107,7 @@ fn lldi_test() {
|
||||
|
||||
#[test]
|
||||
fn lmv_test() {
|
||||
let mut cpu = CPU::new(1);
|
||||
let mut cpu = common::empty_cpu(1);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -148,7 +149,7 @@ fn lmv_test() {
|
||||
|
||||
#[test]
|
||||
fn simple_send_recv_test() {
|
||||
let mut cpu = CPU::new(2);
|
||||
let mut cpu = common::empty_cpu(2);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
|
||||
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
|
||||
|
||||
#[test]
|
||||
fn multiple_send_recv_test() {
|
||||
let mut cpu = CPU::new(4);
|
||||
let mut cpu = common::empty_cpu(4);
|
||||
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
|
||||
let buff: [f32; _] = [
|
||||
1.0, 1.0, 1.0, 1.0, 1.0
|
||||
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
|
||||
];
|
||||
cpu.core(4).execute_store(0, &buff).unwrap();
|
||||
|
||||
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
|
||||
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(from).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
|
||||
);
|
||||
};
|
||||
|
||||
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
|
||||
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
|
||||
let mut idata_build = InstructionDataBuilder::new();
|
||||
idata_build.set_core_indx(to).fix_core_indx();
|
||||
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
|
||||
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
|
||||
|
||||
|
||||
// 1 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
|
||||
send_inst(&mut inst_builder, 1, 3);
|
||||
core_instruction_builder.set_core(1, inst_builder.build());
|
||||
|
||||
// 2 -> 3
|
||||
// 2 <- 4
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
|
||||
send_inst(&mut inst_builder, 2, 3);
|
||||
recv_inst(&mut inst_builder, 2, 4);
|
||||
core_instruction_builder.set_core(2, inst_builder.build());
|
||||
|
||||
// 3 <- 2
|
||||
// 3 <- 4
|
||||
// 3 <- 1
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
|
||||
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
|
||||
recv_inst(&mut inst_builder, 3, 2);
|
||||
recv_inst(&mut inst_builder, 3, 4);
|
||||
recv_inst(&mut inst_builder, 3, 1);
|
||||
core_instruction_builder.set_core(3, inst_builder.build());
|
||||
// 4 -> 2
|
||||
// 4 -> 3
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
|
||||
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
|
||||
send_inst(&mut inst_builder, 4, 2);
|
||||
send_inst(&mut inst_builder, 4, 3);
|
||||
core_instruction_builder.set_core(4, inst_builder.build());
|
||||
|
||||
let mut executable = Executable::new(cpu, core_instruction_builder.build());
|
||||
|
||||
Submodule backend-simulators/pim/pimsim-nn updated: 3e3442b663...6d3b898e6b
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||
|
||||
set(PIM_GENERATED_PATH_SHIM_TARGET "")
|
||||
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
|
||||
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
|
||||
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
|
||||
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
|
||||
|
||||
function(add_pim_generated_path_shim relative_path)
|
||||
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
|
||||
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
|
||||
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${shim_file}"
|
||||
DEPENDS "${real_file}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
|
||||
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
file(GLOB_RECURSE pim_generated_path_scan_sources
|
||||
CONFIGURE_DEPENDS
|
||||
"${PIM_SRC_ROOT}/*.cpp"
|
||||
"${PIM_SRC_ROOT}/*.hpp"
|
||||
)
|
||||
|
||||
set(pim_generated_path_shims)
|
||||
foreach (source_file IN LISTS pim_generated_path_scan_sources)
|
||||
file(READ "${source_file}" source_contents)
|
||||
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
|
||||
|
||||
foreach (inc_match IN LISTS source_inc_matches)
|
||||
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
|
||||
list(APPEND pim_generated_path_shims "${relative_inc_path}")
|
||||
endforeach ()
|
||||
endforeach ()
|
||||
|
||||
list(REMOVE_DUPLICATES pim_generated_path_shims)
|
||||
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
|
||||
add_pim_generated_path_shim("${relative_inc_path}")
|
||||
endforeach ()
|
||||
|
||||
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
|
||||
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
|
||||
endif ()
|
||||
|
||||
set(PIM_PUBLIC_INCLUDE_DIRS
|
||||
${ONNX_MLIR_SRC_ROOT}/include
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
|
||||
|
||||
function(add_pim_library name)
|
||||
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
||||
if (PIM_GENERATED_PATH_SHIM_TARGET)
|
||||
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(Dialect)
|
||||
@@ -68,5 +121,8 @@ add_pim_library(OMPIMAccel
|
||||
OMSpatialToPim
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
OMPimMemoryCoalescing
|
||||
OMPimHostConstantFolding
|
||||
OMPimVerification
|
||||
MLIRTensorInferTypeOpInterfaceImpl
|
||||
)
|
||||
|
||||
@@ -1,5 +1,19 @@
|
||||
add_pim_library(OMPimCommon
|
||||
PimCommon.cpp
|
||||
IR/AffineUtils.cpp
|
||||
IR/AddressAnalysis.cpp
|
||||
IR/BatchCoreUtils.cpp
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/LoopUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
IR/SubviewUtils.cpp
|
||||
IR/WeightUtils.cpp
|
||||
Support/CheckedArithmetic.cpp
|
||||
Support/DebugDump.cpp
|
||||
Support/Diagnostics.cpp
|
||||
Support/FileSystemUtils.cpp
|
||||
Support/ReportUtils.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
@@ -7,6 +21,8 @@ add_pim_library(OMPimCommon
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
onnx
|
||||
SpatialOps
|
||||
PimOps
|
||||
|
||||
@@ -0,0 +1,807 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
|
||||
if (!moduleOp || !getGlobalOp)
|
||||
return {};
|
||||
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
if (!knowledge)
|
||||
return value;
|
||||
|
||||
auto iter = knowledge->aliases.find(value);
|
||||
while (iter != knowledge->aliases.end()) {
|
||||
value = iter->second;
|
||||
iter = knowledge->aliases.find(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
||||
|
||||
template <typename... Args>
|
||||
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
if (mlir::isa<mlir::BlockArgument>(value))
|
||||
return value;
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return value;
|
||||
|
||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
|
||||
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||
return mlir::failure();
|
||||
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
|
||||
return mlir::failure();
|
||||
return strides;
|
||||
}
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||
return mlir::failure();
|
||||
|
||||
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
|
||||
switch (predicate) {
|
||||
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
|
||||
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
|
||||
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
|
||||
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
|
||||
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
|
||||
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
|
||||
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cmpi predicate");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
if (!expr.node)
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
|
||||
case CompiledIndexExprNode::Kind::Symbol: {
|
||||
auto value = resolveAlias(expr.node->symbol, &knowledge);
|
||||
auto iter = knowledge.indexValues.find(value);
|
||||
if (iter != knowledge.indexValues.end())
|
||||
return iter->second;
|
||||
return mlir::failure();
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Add:
|
||||
case CompiledIndexExprNode::Kind::Sub:
|
||||
case CompiledIndexExprNode::Kind::Mul:
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
case CompiledIndexExprNode::Kind::DivSI:
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
case CompiledIndexExprNode::Kind::CmpI: {
|
||||
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
|
||||
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
|
||||
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
case CompiledIndexExprNode::Kind::DivSI:
|
||||
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
|
||||
return mlir::failure();
|
||||
return *lhs / *rhs;
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return 0;
|
||||
return *lhs % *rhs;
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
||||
default: llvm_unreachable("unexpected binary compiled index kind");
|
||||
}
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Select: {
|
||||
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
if (failed(condition))
|
||||
return mlir::failure();
|
||||
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
|
||||
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
|
||||
if (!denseAttr || !globalType)
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(expr.node->operands.size());
|
||||
for (const CompiledIndexExpr& operand : expr.node->operands) {
|
||||
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown compiled index kind");
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
|
||||
expr.globalOp = globalOp;
|
||||
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
|
||||
expr.operands.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto compiledIndex = compileIndexValueImpl(index);
|
||||
if (failed(compiledIndex))
|
||||
return mlir::failure();
|
||||
expr.operands.push_back(*compiledIndex);
|
||||
}
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = integerAttr.getInt();
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Symbol;
|
||||
expr.symbol = value;
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
|
||||
auto lhs = compileIndexValueImpl(lhsValue);
|
||||
auto rhs = compileIndexValueImpl(rhsValue);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = kind;
|
||||
expr.operands = {*lhs, *rhs};
|
||||
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
|
||||
};
|
||||
|
||||
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
||||
return compileIndexValueImpl(indexCastOp.getIn());
|
||||
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
|
||||
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
|
||||
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
|
||||
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
|
||||
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
|
||||
if (failed(expr))
|
||||
return mlir::failure();
|
||||
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
|
||||
exprNode->predicate = cmpOp.getPredicate();
|
||||
return CompiledIndexExpr(exprNode);
|
||||
}
|
||||
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
|
||||
auto lhs = compileIndexValueImpl(maxOp.getLhs());
|
||||
auto rhs = compileIndexValueImpl(maxOp.getRhs());
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
CompiledIndexExprNode cmpExpr;
|
||||
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
|
||||
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
|
||||
cmpExpr.operands = {*lhs, *rhs};
|
||||
|
||||
CompiledIndexExprNode selectExpr;
|
||||
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
|
||||
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
|
||||
return makeCompiledIndexExpr(std::move(selectExpr));
|
||||
}
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
|
||||
auto condition = compileIndexValueImpl(selectOp.getCondition());
|
||||
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
|
||||
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
|
||||
if (failed(condition) || failed(trueValue) || failed(falseValue))
|
||||
return mlir::failure();
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Select;
|
||||
expr.operands = {*condition, *trueValue, *falseValue};
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return compileConstantGlobalLoad(loadOp);
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Symbol;
|
||||
expr.symbol = value;
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
if (knowledge) {
|
||||
auto iter = knowledge->indexValues.find(value);
|
||||
if (iter != knowledge->indexValues.end())
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
|
||||
if (constantOp) {
|
||||
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
|
||||
return integerAttr.getInt();
|
||||
}
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
||||
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
||||
|
||||
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return *lhs + *rhs;
|
||||
}
|
||||
|
||||
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return *lhs - *rhs;
|
||||
}
|
||||
|
||||
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return *lhs * *rhs;
|
||||
}
|
||||
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return mlir::failure();
|
||||
return *lhs / *rhs;
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return 0;
|
||||
return *lhs % *rhs;
|
||||
}
|
||||
|
||||
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
|
||||
}
|
||||
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
|
||||
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
|
||||
if (failed(condition))
|
||||
return mlir::failure();
|
||||
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
|
||||
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
||||
if (!integerAttr)
|
||||
return mlir::failure();
|
||||
return integerAttr.getInt();
|
||||
}
|
||||
|
||||
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
int64_t byteOffset = 0;
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
while (true) {
|
||||
if (mlir::isa<mlir::BlockArgument>(value))
|
||||
return ResolvedContiguousAddress {value, byteOffset};
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return mlir::failure();
|
||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||
if (!result)
|
||||
return mlir::failure();
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||
strides.reserve(subviewOp.getMixedStrides().size());
|
||||
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
offsets.push_back(*resolvedOffset);
|
||||
}
|
||||
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||
if (failed(resolvedSize))
|
||||
return mlir::failure();
|
||||
sizes.push_back(*resolvedSize);
|
||||
}
|
||||
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||
if (failed(resolvedStride))
|
||||
return mlir::failure();
|
||||
strides.push_back(*resolvedStride);
|
||||
}
|
||||
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
||||
value = resolveAlias(castOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
|
||||
return ResolvedContiguousAddress {value, byteOffset};
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
|
||||
int64_t constantByteOffset = 0;
|
||||
CompiledIndexExpr byteOffsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = 0;
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (mlir::isa<mlir::BlockArgument>(value))
|
||||
return CompiledAddressExpr {value, byteOffsetExpr};
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return mlir::failure();
|
||||
value = tiedOperand->get();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||
if (!result)
|
||||
return mlir::failure();
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.reserve(subviewOp.getMixedSizes().size());
|
||||
llvm::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
bool hasOnlyStaticOffsets = true;
|
||||
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
hasOnlyStaticOffsets = false;
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
|
||||
if (!attr)
|
||||
return mlir::failure();
|
||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
}
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
|
||||
if (!attr)
|
||||
return mlir::failure();
|
||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
}
|
||||
|
||||
if (!isContiguousSubviewWithDynamicOffsets(
|
||||
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (hasOnlyStaticOffsets) {
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
constantByteOffset +=
|
||||
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
}
|
||||
else {
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr offsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = 0;
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
|
||||
CompiledIndexExpr operandExpr;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
|
||||
* getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
operandExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
else {
|
||||
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
|
||||
if (failed(compiledOffset))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr scaleExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
scaleExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Mul;
|
||||
expr.operands = {*compiledOffset, scaleExpr};
|
||||
operandExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {offsetExpr, operandExpr};
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExpr constantExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constantByteOffset;
|
||||
constantExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {constantExpr, offsetExpr};
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
constantByteOffset = 0;
|
||||
}
|
||||
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
|
||||
if (constantByteOffset != 0) {
|
||||
CompiledIndexExpr constantExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constantByteOffset;
|
||||
constantExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
|
||||
byteOffsetExpr = constantExpr;
|
||||
else {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {constantExpr, byteOffsetExpr};
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
}
|
||||
return CompiledAddressExpr {value, byteOffsetExpr};
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveIndexValueImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
return resolveContiguousAddressImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
|
||||
return compileContiguousAddressExprImpl(value);
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
|
||||
return evaluateCompiledIndexExpr(*this, knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
|
||||
std::optional<unsigned> lane) const {
|
||||
(void) lane;
|
||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
return ResolvedContiguousAddress {base, *resolvedOffset};
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,94 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Describes a value as a base addressable object plus a statically known
|
||||
/// byte offset after peeling aliases, casts, and contiguous subviews.
|
||||
struct ResolvedContiguousAddress {
|
||||
mlir::Value base;
|
||||
int64_t byteOffset = 0;
|
||||
};
|
||||
|
||||
/// Records compile-time facts used when interpreting address arithmetic and
|
||||
/// loop-carried aliases inside PIM regions.
|
||||
struct StaticValueKnowledge {
|
||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
struct CompiledIndexExprNode;
|
||||
|
||||
struct CompiledIndexExpr {
|
||||
std::shared_ptr<CompiledIndexExprNode> node;
|
||||
|
||||
CompiledIndexExpr() = default;
|
||||
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
|
||||
: node(std::move(node)) {}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
struct CompiledIndexExprNode {
|
||||
enum class Kind {
|
||||
Constant,
|
||||
Symbol,
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
DivUI,
|
||||
DivSI,
|
||||
RemUI,
|
||||
RemSI,
|
||||
MinUI,
|
||||
CmpI,
|
||||
Select,
|
||||
ConstantGlobalLoad
|
||||
};
|
||||
|
||||
Kind kind = Kind::Constant;
|
||||
int64_t constant = 0;
|
||||
mlir::Value symbol;
|
||||
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t, 4> globalStrides;
|
||||
llvm::SmallVector<CompiledIndexExpr, 4> operands;
|
||||
};
|
||||
|
||||
struct CompiledAddressExpr {
|
||||
mlir::Value base;
|
||||
CompiledIndexExpr byteOffset;
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
|
||||
std::optional<unsigned> lane) const;
|
||||
};
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
/// Resolves a value to contiguous backing storage when that storage can be
|
||||
/// proven statically from aliases, DPS ties, casts, and subviews.
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
/// Statically evaluates index-like SSA values, including simple integer
|
||||
/// arithmetic and loop facts recorded in `knowledge`.
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
|
||||
|
||||
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||
/// loop-carried memref/result.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,182 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "AffineUtils.hpp"
|
||||
#include "ConstantUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
|
||||
if (rhs <= 0)
|
||||
return failure();
|
||||
|
||||
int64_t quotient = lhs / rhs;
|
||||
int64_t remainder = lhs % rhs;
|
||||
if (remainder != 0 && lhs < 0)
|
||||
--quotient;
|
||||
return quotient;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
|
||||
if (rhs <= 0)
|
||||
return failure();
|
||||
|
||||
int64_t quotient = lhs / rhs;
|
||||
int64_t remainder = lhs % rhs;
|
||||
if (remainder != 0 && lhs > 0)
|
||||
++quotient;
|
||||
return quotient;
|
||||
}
|
||||
|
||||
Value createOrFoldAffineApply(
|
||||
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");
|
||||
|
||||
SmallVector<Attribute> operandConstants;
|
||||
operandConstants.reserve(operands.size());
|
||||
for (Value operand : operands) {
|
||||
std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
|
||||
if (!constantValue)
|
||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
|
||||
}
|
||||
|
||||
SmallVector<Attribute> foldedResults;
|
||||
if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1)
|
||||
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt());
|
||||
|
||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||
}
|
||||
|
||||
Value createOrFoldAffineApply(
|
||||
RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
|
||||
return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
if (multiplier == 0)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
||||
if (divisor == 1)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineFloorDivConst(
|
||||
RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
|
||||
if (divisor == 1)
|
||||
return value;
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
||||
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
||||
return constant.getValue();
|
||||
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
|
||||
unsigned position = dim.getPosition();
|
||||
if (position >= dims.size())
|
||||
return failure();
|
||||
return dims[position];
|
||||
}
|
||||
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||
unsigned position = symbol.getPosition();
|
||||
if (position >= symbols.size())
|
||||
return failure();
|
||||
return symbols[position];
|
||||
}
|
||||
|
||||
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
|
||||
if (!binary)
|
||||
return failure();
|
||||
|
||||
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
|
||||
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return failure();
|
||||
|
||||
switch (binary.getKind()) {
|
||||
case AffineExprKind::Add: return *lhs + *rhs;
|
||||
case AffineExprKind::Mul: return *lhs * *rhs;
|
||||
case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
|
||||
case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
|
||||
case AffineExprKind::Mod: {
|
||||
FailureOr<int64_t> div = floorDivSigned(*lhs, *rhs);
|
||||
if (failed(div))
|
||||
return failure();
|
||||
return *lhs - *div * *rhs;
|
||||
}
|
||||
default: return failure();
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
|
||||
if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
|
||||
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
|
||||
return evaluateAffineExpr(map.getResult(0), dims, symbols);
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
|
||||
SmallVector<int64_t, 4> operands;
|
||||
operands.reserve(affineApply.getMapOperands().size());
|
||||
for (Value operand : affineApply.getMapOperands()) {
|
||||
FailureOr<int64_t> folded = resolver(operand);
|
||||
if (failed(folded))
|
||||
return failure();
|
||||
operands.push_back(*folded);
|
||||
}
|
||||
|
||||
return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
|
||||
}
|
||||
|
||||
bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; }
|
||||
|
||||
bool isDimAndConstantAffineExpr(AffineExpr expr) {
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Constant:
|
||||
case AffineExprKind::DimId: return true;
|
||||
case AffineExprKind::SymbolId: return false;
|
||||
case AffineExprKind::Add: {
|
||||
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS());
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()))
|
||||
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()));
|
||||
}
|
||||
case AffineExprKind::FloorDiv:
|
||||
case AffineExprKind::CeilDiv:
|
||||
case AffineExprKind::Mod: {
|
||||
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS());
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("unexpected affine expression kind");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using IndexValueResolver = llvm::function_ref<llvm::FailureOr<int64_t>(mlir::Value)>;
|
||||
|
||||
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineMap map,
|
||||
mlir::ValueRange operands,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineExpr expr,
|
||||
mlir::ValueRange dims,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t multiplier,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
llvm::FailureOr<int64_t>
|
||||
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef<int64_t> operands);
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver);
|
||||
|
||||
bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map);
|
||||
|
||||
bool isDimAndConstantAffineExpr(mlir::AffineExpr expr);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,32 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,745 @@
|
||||
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
||||
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
|
||||
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace compact_asm {
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
enum class ListDelimiter {
|
||||
Square,
|
||||
Paren
|
||||
};
|
||||
|
||||
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
if (delimiter == ListDelimiter::Square)
|
||||
return parser.parseLSquare();
|
||||
return parser.parseLParen();
|
||||
}
|
||||
|
||||
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
|
||||
if (delimiter == ListDelimiter::Square)
|
||||
return parser.parseOptionalRSquare();
|
||||
return parser.parseOptionalRParen();
|
||||
}
|
||||
|
||||
template <typename StreamT>
|
||||
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
||||
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
|
||||
}
|
||||
|
||||
template <typename StreamT>
|
||||
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
|
||||
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
|
||||
}
|
||||
|
||||
template <typename EntryT, typename ParseEntryFn>
|
||||
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<EntryT>& entries,
|
||||
ParseEntryFn parseEntry) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
EntryT entry;
|
||||
if (parseEntry(entry))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
entries.push_back(entry);
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline ParseResult
|
||||
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<IntT> subgroup;
|
||||
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(values, subgroup);
|
||||
}
|
||||
else {
|
||||
int64_t first = 0;
|
||||
if (parser.parseInteger(first))
|
||||
return failure();
|
||||
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
int64_t last = 0;
|
||||
if (parser.parseInteger(last) || last < first)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
|
||||
|
||||
int64_t step = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("by"))) {
|
||||
if (parser.parseInteger(step) || step <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
|
||||
}
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
if ((last - first) % step != 0) {
|
||||
return parser.emitError(parser.getCurrentLocation(),
|
||||
"range end must be reachable from start using the given step");
|
||||
}
|
||||
|
||||
for (int64_t value = first; value <= last; value += step)
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(value));
|
||||
}
|
||||
else {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
values.push_back(static_cast<IntT>(first));
|
||||
}
|
||||
}
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline ParseResult
|
||||
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
return parseCompressedIntegerEntries(parser, delimiter, values);
|
||||
}
|
||||
|
||||
template <typename RangeT, typename PrintEntryFn>
|
||||
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
|
||||
for (size_t index = 0; index < entries.size();) {
|
||||
size_t runEnd = index + 1;
|
||||
while (runEnd < entries.size() && entries[runEnd] == entries[index])
|
||||
++runEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printEntry(entries[index]);
|
||||
size_t runLength = runEnd - index;
|
||||
if (runLength > 1)
|
||||
printer << " x" << runLength;
|
||||
index = runEnd;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StreamT, typename IntT>
|
||||
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
|
||||
struct FlatCompression {
|
||||
enum class Kind {
|
||||
Single,
|
||||
EqualRun,
|
||||
Progression
|
||||
};
|
||||
|
||||
Kind kind = Kind::Single;
|
||||
size_t covered = 1;
|
||||
size_t repeatCount = 1;
|
||||
size_t progressionValueCount = 1;
|
||||
int64_t step = 1;
|
||||
IntT firstValue {};
|
||||
IntT lastValue {};
|
||||
};
|
||||
|
||||
auto computeFlatCompression = [&](size_t start) {
|
||||
FlatCompression compression;
|
||||
compression.firstValue = values[start];
|
||||
compression.lastValue = values[start];
|
||||
|
||||
auto findEqualRunEnd = [&](size_t runStart) {
|
||||
size_t runEnd = runStart + 1;
|
||||
while (runEnd < values.size() && values[runEnd] == values[runStart])
|
||||
++runEnd;
|
||||
return runEnd;
|
||||
};
|
||||
|
||||
size_t firstRunEnd = findEqualRunEnd(start);
|
||||
compression.repeatCount = firstRunEnd - start;
|
||||
size_t progressionEnd = firstRunEnd;
|
||||
int64_t step = 0;
|
||||
IntT lastValue = values[start];
|
||||
|
||||
if (firstRunEnd < values.size()) {
|
||||
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
|
||||
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
|
||||
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
|
||||
progressionEnd = secondRunEnd;
|
||||
lastValue = values[firstRunEnd];
|
||||
size_t currentRunStart = secondRunEnd;
|
||||
while (currentRunStart < values.size()) {
|
||||
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
|
||||
if (currentRunEnd - currentRunStart != compression.repeatCount)
|
||||
break;
|
||||
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
|
||||
break;
|
||||
lastValue = values[currentRunStart];
|
||||
progressionEnd = currentRunEnd;
|
||||
currentRunStart = currentRunEnd;
|
||||
}
|
||||
}
|
||||
else {
|
||||
step = 0;
|
||||
}
|
||||
}
|
||||
|
||||
compression.covered = 1;
|
||||
if (progressionEnd > firstRunEnd) {
|
||||
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
|
||||
if (progressionValueCount >= 3) {
|
||||
compression.kind = FlatCompression::Kind::Progression;
|
||||
compression.covered = progressionEnd - start;
|
||||
compression.progressionValueCount = progressionValueCount;
|
||||
compression.step = step;
|
||||
compression.lastValue = lastValue;
|
||||
return compression;
|
||||
}
|
||||
}
|
||||
|
||||
if (compression.repeatCount > 1) {
|
||||
compression.kind = FlatCompression::Kind::EqualRun;
|
||||
compression.covered = compression.repeatCount;
|
||||
return compression;
|
||||
}
|
||||
|
||||
return compression;
|
||||
};
|
||||
|
||||
auto findRepeatedSublist = [&](size_t start) {
|
||||
size_t bestLength = 0;
|
||||
size_t bestRepeatCount = 1;
|
||||
size_t remaining = values.size() - start;
|
||||
|
||||
for (size_t length = 2; length * 2 <= remaining; ++length) {
|
||||
size_t repeatCount = 1;
|
||||
ArrayRef<IntT> candidate = values.slice(start, length);
|
||||
while (start + (repeatCount + 1) * length <= values.size()
|
||||
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
|
||||
++repeatCount;
|
||||
}
|
||||
|
||||
if (repeatCount <= 1)
|
||||
continue;
|
||||
|
||||
size_t covered = length * repeatCount;
|
||||
size_t bestCovered = bestLength * bestRepeatCount;
|
||||
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
|
||||
bestLength = length;
|
||||
bestRepeatCount = repeatCount;
|
||||
}
|
||||
}
|
||||
|
||||
return std::pair(bestLength, bestRepeatCount);
|
||||
};
|
||||
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
if (index != 0)
|
||||
stream << ", ";
|
||||
|
||||
FlatCompression flat = computeFlatCompression(index);
|
||||
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
|
||||
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
|
||||
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
|
||||
printOpenDelimiter(stream, ListDelimiter::Paren);
|
||||
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
|
||||
printCloseDelimiter(stream, ListDelimiter::Paren);
|
||||
stream << " x" << sublistRepeatCount;
|
||||
index += repeatedSublistCoverage;
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (flat.kind) {
|
||||
case FlatCompression::Kind::Progression:
|
||||
stream << flat.firstValue << " to " << flat.lastValue;
|
||||
if (flat.step != 1)
|
||||
stream << " by " << flat.step;
|
||||
if (flat.repeatCount > 1)
|
||||
stream << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::EqualRun:
|
||||
stream << flat.firstValue << " x" << flat.repeatCount;
|
||||
index += flat.covered;
|
||||
break;
|
||||
case FlatCompression::Kind::Single:
|
||||
stream << flat.firstValue;
|
||||
index += flat.covered;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename StreamT, typename IntT>
|
||||
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(stream, delimiter);
|
||||
printCompressedIntegerEntries(stream, values);
|
||||
printCloseDelimiter(stream, delimiter);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
|
||||
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
|
||||
}
|
||||
|
||||
template <typename IntT>
|
||||
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
|
||||
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
|
||||
}
|
||||
|
||||
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
|
||||
for (size_t index = 0; index < values.size();) {
|
||||
size_t equalRunEnd = index + 1;
|
||||
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
|
||||
++equalRunEnd;
|
||||
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
if (equalRunEnd - index > 1) {
|
||||
printer.printOperand(values[index]);
|
||||
printer << " x" << (equalRunEnd - index);
|
||||
index = equalRunEnd;
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t rangeEnd = index + 1;
|
||||
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
|
||||
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|
||||
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
|
||||
break;
|
||||
}
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
|
||||
while (rangeEnd < values.size()) {
|
||||
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
|
||||
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|
||||
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
|
||||
break;
|
||||
}
|
||||
++rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
printer.printOperand(values[index]);
|
||||
if (rangeEnd - index >= 3) {
|
||||
printer << " to ";
|
||||
printer.printOperand(values[rangeEnd - 1]);
|
||||
}
|
||||
else if (rangeEnd - index == 2) {
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index + 1]);
|
||||
}
|
||||
index = rangeEnd;
|
||||
}
|
||||
}
|
||||
|
||||
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
|
||||
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
|
||||
}
|
||||
|
||||
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printCompressedValueSequence(printer, values);
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printCompressedTypeSequence(printer, types);
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
|
||||
Type firstType;
|
||||
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
|
||||
if (!firstTypeResult.has_value()) {
|
||||
if (allowEmpty)
|
||||
return success();
|
||||
return parser.emitError(parser.getCurrentLocation(), "expected type");
|
||||
}
|
||||
if (failed(*firstTypeResult))
|
||||
return failure();
|
||||
|
||||
auto appendType = [&](Type type) -> ParseResult {
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
types.push_back(type);
|
||||
return success();
|
||||
};
|
||||
|
||||
if (appendType(firstType))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
Type nextType;
|
||||
if (parser.parseType(nextType) || appendType(nextType))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::UnresolvedOperand firstOperand,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::UnresolvedOperand lastOperand;
|
||||
if (parser.parseOperand(lastOperand))
|
||||
return failure();
|
||||
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
|
||||
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
|
||||
operands.push_back({firstOperand.location, firstOperand.name, number});
|
||||
return success();
|
||||
}
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t index = 0; index < repeatCount; ++index)
|
||||
operands.push_back(firstOperand);
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
OpAsmParser::UnresolvedOperand firstOperand;
|
||||
if (parser.parseOperand(firstOperand))
|
||||
return failure();
|
||||
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
|
||||
return failure();
|
||||
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||
}
|
||||
|
||||
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
|
||||
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
|
||||
return false;
|
||||
|
||||
SmallVector<Value> valueVec(values.begin(), values.end());
|
||||
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
|
||||
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
|
||||
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
|
||||
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
|
||||
return false;
|
||||
|
||||
SmallVector<Type> typeVec(types.begin(), types.end());
|
||||
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
|
||||
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
|
||||
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printOperand(values[index]);
|
||||
}
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " x" << (values.size() / tupleSize);
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
|
||||
printOpenDelimiter(printer, delimiter);
|
||||
printOpenDelimiter(printer, ListDelimiter::Paren);
|
||||
for (size_t index = 0; index < tupleSize; ++index) {
|
||||
if (index != 0)
|
||||
printer << ", ";
|
||||
printer.printType(types[index]);
|
||||
}
|
||||
printCloseDelimiter(printer, ListDelimiter::Paren);
|
||||
printer << " x" << (types.size() / tupleSize);
|
||||
printCloseDelimiter(printer, delimiter);
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
|
||||
ListDelimiter delimiter,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
|
||||
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(operands, tupleOperands);
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
tupleOperands.clear();
|
||||
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(operands, tupleOperands);
|
||||
}
|
||||
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||
}
|
||||
|
||||
while (true) {
|
||||
if (parseOneCompressedOperandEntry(parser, operands))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
inline ParseResult
|
||||
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
|
||||
if (parseOpenDelimiter(parser, delimiter))
|
||||
return failure();
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
SmallVector<Type> tupleTypes;
|
||||
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(types, tupleTypes);
|
||||
|
||||
while (succeeded(parser.parseOptionalComma())) {
|
||||
if (parser.parseLParen())
|
||||
return failure();
|
||||
tupleTypes.clear();
|
||||
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
|
||||
return failure();
|
||||
|
||||
repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
llvm::append_range(types, tupleTypes);
|
||||
}
|
||||
return parseOptionalCloseDelimiter(parser, delimiter);
|
||||
}
|
||||
|
||||
while (true) {
|
||||
Type type;
|
||||
if (parser.parseType(type))
|
||||
return failure();
|
||||
|
||||
int64_t repeatCount = 1;
|
||||
if (succeeded(parser.parseOptionalKeyword("x"))) {
|
||||
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
|
||||
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
|
||||
}
|
||||
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
|
||||
types.push_back(type);
|
||||
|
||||
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
|
||||
return success();
|
||||
if (parser.parseComma())
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
|
||||
if (block.getNumArguments() == 0) {
|
||||
printer << "() = ()";
|
||||
return;
|
||||
}
|
||||
|
||||
if (block.getNumArguments() == 1) {
|
||||
printer.printOperand(block.getArgument(0));
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
return;
|
||||
}
|
||||
|
||||
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
|
||||
printer << " = ";
|
||||
printCompressedValueList(printer, operands, ListDelimiter::Paren);
|
||||
}
|
||||
|
||||
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
|
||||
OpAsmParser::Argument firstArgument,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
if (succeeded(parser.parseOptionalKeyword("to"))) {
|
||||
OpAsmParser::Argument lastArgument;
|
||||
if (parser.parseArgument(lastArgument))
|
||||
return failure();
|
||||
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|
||||
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
|
||||
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
|
||||
}
|
||||
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
|
||||
OpAsmParser::Argument argument;
|
||||
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
|
||||
arguments.push_back(argument);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
arguments.push_back(firstArgument);
|
||||
return success();
|
||||
}
|
||||
|
||||
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument))
|
||||
return failure();
|
||||
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
|
||||
}
|
||||
|
||||
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
|
||||
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
|
||||
argument.type = inputType;
|
||||
}
|
||||
|
||||
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
|
||||
SmallVectorImpl<OpAsmParser::Argument>& arguments,
|
||||
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
if (succeeded(parser.parseOptionalRParen())) {
|
||||
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument firstArgument;
|
||||
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
|
||||
return failure();
|
||||
while (succeeded(parser.parseOptionalComma()))
|
||||
if (parseOneCompressedArgumentEntry(parser, arguments))
|
||||
return failure();
|
||||
if (parser.parseRParen() || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
OpAsmParser::Argument argument;
|
||||
if (parser.parseArgument(argument) || parser.parseEqual()
|
||||
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
|
||||
return failure();
|
||||
}
|
||||
arguments.push_back(argument);
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace compact_asm
|
||||
} // namespace onnx_mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,157 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
|
||||
if (!constantOp.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
|
||||
if (!intAttr || !intAttr.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
return intAttr.getInt();
|
||||
}
|
||||
|
||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
|
||||
return &funcOp.getBody().front();
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
|
||||
return moduleOp.getBody();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
return moduleOp.getBody();
|
||||
return anchorOp->getBlock();
|
||||
}
|
||||
|
||||
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getConstantInsertionBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getConstantInsertionBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(hostBlock);
|
||||
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||
}
|
||||
|
||||
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
|
||||
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
|
||||
}
|
||||
|
||||
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
|
||||
if (funcOp.getBody().empty())
|
||||
return;
|
||||
|
||||
Block& entryBlock = funcOp.getBody().front();
|
||||
DenseMap<int64_t, Value> canonicalByValue;
|
||||
SmallVector<arith::ConstantOp> constants;
|
||||
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (!getIndexConstantValue(constantOp))
|
||||
return;
|
||||
constants.push_back(constantOp);
|
||||
});
|
||||
|
||||
for (arith::ConstantOp constantOp : constants) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value || constantOp->getBlock() != &entryBlock)
|
||||
continue;
|
||||
canonicalByValue.try_emplace(*value, constantOp.getResult());
|
||||
}
|
||||
|
||||
for (arith::ConstantOp constantOp : constants) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value)
|
||||
continue;
|
||||
|
||||
Value canonical = canonicalByValue.lookup(*value);
|
||||
if (!canonical) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&entryBlock);
|
||||
Builder builder(funcOp.getContext());
|
||||
canonical =
|
||||
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
|
||||
canonicalByValue[*value] = canonical;
|
||||
}
|
||||
|
||||
if (constantOp.getResult() == canonical)
|
||||
continue;
|
||||
|
||||
constantOp.getResult().replaceAllUsesWith(canonical);
|
||||
}
|
||||
|
||||
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
|
||||
auto value = getIndexConstantValue(constantOp);
|
||||
if (!value)
|
||||
continue;
|
||||
if (constantOp.getResult() == canonicalByValue.lookup(*value))
|
||||
continue;
|
||||
if (constantOp.use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(Value value) {
|
||||
if (!value || !value.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
|
||||
return constant.value();
|
||||
|
||||
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
|
||||
return intAttr.getInt();
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
|
||||
if (auto attr = dyn_cast<Attribute>(value))
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
|
||||
return intAttr.getInt();
|
||||
if (auto operand = dyn_cast<Value>(value))
|
||||
return matchConstantIndexValue(operand);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,33 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
|
||||
|
||||
mlir::Value
|
||||
getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
|
||||
|
||||
mlir::Value
|
||||
getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
|
||||
|
||||
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
||||
|
||||
mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter);
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,137 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
if (mlir::isa<mlir::arith::ConstantOp,
|
||||
mlir::arith::AddIOp,
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::DivSIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::RemSIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::arith::CmpIOp,
|
||||
mlir::memref::AllocOp,
|
||||
mlir::memref::SubViewOp,
|
||||
mlir::memref::CastOp,
|
||||
mlir::memref::CollapseShapeOp,
|
||||
mlir::memref::ExpandShapeOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
|
||||
return selectOp.getType().isIntOrIndex();
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
||||
bool hasFailure = false;
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
||||
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
||||
StaticValueKnowledge loopKnowledge = knowledge;
|
||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
||||
loopKnowledge.aliases[iterArg] = iterValue;
|
||||
|
||||
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
||||
hasFailure = true;
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
|
||||
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
||||
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(callback(op, knowledge)))
|
||||
hasFailure = true;
|
||||
}
|
||||
return mlir::success(!hasFailure);
|
||||
}
|
||||
|
||||
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||
mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
||||
bool hasFailure = false;
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
|
||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
if (*step <= 0) {
|
||||
forOp.emitOpError("requires positive scf.for step for PIM verification");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 2> samples;
|
||||
if (*lowerBound < *upperBound) {
|
||||
samples.push_back(*lowerBound);
|
||||
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
|
||||
if (last != *lowerBound)
|
||||
samples.push_back(last);
|
||||
}
|
||||
|
||||
for (int64_t inductionValue : samples) {
|
||||
StaticValueKnowledge loopKnowledge = knowledge;
|
||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
|
||||
loopKnowledge.aliases[iterArg] = iterValue;
|
||||
|
||||
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(callback(op, knowledge)))
|
||||
hasFailure = true;
|
||||
}
|
||||
return mlir::success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,32 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true for ops in a `pim.core` body that only participate in static
|
||||
/// address or index computation and therefore do not emit PIM instructions.
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||
|
||||
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
|
||||
/// their bounds are known and invoking `callback` only on instruction-emitting
|
||||
/// operations.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
/// Walks a `pim.core`-like body structurally for verification without
|
||||
/// enumerating full loop trip counts. Loop bounds must still be statically
|
||||
/// evaluable so address resolution remains well-defined.
|
||||
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||
mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,45 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
|
||||
if (!moduleOp)
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
|
||||
if (entryPoints.size() > 1) {
|
||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!entryPoints.empty()) {
|
||||
auto entryPointAttr =
|
||||
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||
if (!entryPointAttr) {
|
||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||
return mlir::failure();
|
||||
}
|
||||
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||
if (!entryFunc) {
|
||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||
<< entryPointAttr.getLeafReference().getValue();
|
||||
return mlir::failure();
|
||||
}
|
||||
return entryFunc;
|
||||
}
|
||||
|
||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
|
||||
return mainGraphFunc;
|
||||
|
||||
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
|
||||
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
|
||||
if (!funcOp.isExternal())
|
||||
nonExternalFuncs.push_back(funcOp);
|
||||
if (nonExternalFuncs.size() == 1)
|
||||
return nonExternalFuncs.front();
|
||||
|
||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Resolves the function the PIM pipeline should treat as its entry point.
|
||||
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
|
||||
/// non-external function if the module is otherwise unambiguous.
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,96 @@
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "llvm/Support/MathExtras.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
#include "LoopUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static std::optional<int64_t> getStaticTripCount(Value lowerBound, Value upperBound, Value step) {
|
||||
auto lower = matchConstantIndexValue(lowerBound);
|
||||
auto upper = matchConstantIndexValue(upperBound);
|
||||
auto stepValue = matchConstantIndexValue(step);
|
||||
if (!lower || !upper || !stepValue)
|
||||
return std::nullopt;
|
||||
if (*stepValue <= 0)
|
||||
return std::nullopt;
|
||||
if (*upper <= *lower)
|
||||
return int64_t {0};
|
||||
return llvm::divideCeil(*upper - *lower, *stepValue);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef<Value> yieldedValues) {
|
||||
if (yieldedValues.size() == initArgs.size())
|
||||
return success();
|
||||
|
||||
emitError(loc) << "normalized loop body yielded " << yieldedValues.size() << " values for " << initArgs.size()
|
||||
<< " iter args";
|
||||
return failure();
|
||||
}
|
||||
|
||||
FailureOr<NormalizedLoopResult> buildNormalizedScfFor(OpBuilder& builder,
|
||||
Location loc,
|
||||
Value lowerBound,
|
||||
Value upperBound,
|
||||
Value step,
|
||||
ValueRange initArgs,
|
||||
NormalizedLoopBodyBuilder bodyBuilder) {
|
||||
NormalizedLoopResult result;
|
||||
|
||||
if (auto stepValue = matchConstantIndexValue(step); stepValue && *stepValue <= 0) {
|
||||
emitError(loc) << "normalized scf.for requires a positive step, got " << *stepValue;
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (auto tripCount = getStaticTripCount(lowerBound, upperBound, step)) {
|
||||
if (*tripCount == 0) {
|
||||
llvm::append_range(result.results, initArgs);
|
||||
return result;
|
||||
}
|
||||
|
||||
if (*tripCount == 1) {
|
||||
result.inductionVar = lowerBound;
|
||||
if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, result.results)))
|
||||
return failure();
|
||||
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results)))
|
||||
return failure();
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
result.loop = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs);
|
||||
result.inductionVar = result.loop.getInductionVar();
|
||||
|
||||
{
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Block* body = result.loop.getBody();
|
||||
if (!body->empty())
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(body->back()))
|
||||
yieldOp->erase();
|
||||
builder.setInsertionPointToEnd(body);
|
||||
ValueRange iterArgs = result.loop.getRegionIterArgs();
|
||||
if (failed(bodyBuilder(builder, loc, result.inductionVar, iterArgs, result.results))) {
|
||||
result.loop.erase();
|
||||
return failure();
|
||||
}
|
||||
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results))) {
|
||||
result.loop.erase();
|
||||
return failure();
|
||||
}
|
||||
scf::YieldOp::create(builder, loc, result.results);
|
||||
}
|
||||
builder.setInsertionPointAfter(result.loop);
|
||||
result.results.assign(result.loop.getResults().begin(), result.loop.getResults().end());
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct NormalizedLoopResult {
|
||||
mlir::Value inductionVar;
|
||||
llvm::SmallVector<mlir::Value, 4> results;
|
||||
mlir::scf::ForOp loop;
|
||||
|
||||
bool wasInlined() const { return !loop; }
|
||||
};
|
||||
|
||||
using NormalizedLoopBodyBuilder = llvm::function_ref<mlir::LogicalResult(
|
||||
mlir::OpBuilder&, mlir::Location, mlir::Value, mlir::ValueRange, llvm::SmallVectorImpl<mlir::Value>&)>;
|
||||
|
||||
mlir::FailureOr<NormalizedLoopResult> buildNormalizedScfFor(mlir::OpBuilder& builder,
|
||||
mlir::Location loc,
|
||||
mlir::Value lowerBound,
|
||||
mlir::Value upperBound,
|
||||
mlir::Value step,
|
||||
mlir::ValueRange initArgs,
|
||||
NormalizedLoopBodyBuilder bodyBuilder);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,166 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
|
||||
llvm::SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
|
||||
llvm::SmallVector<int64_t> indices(shape.size(), 0);
|
||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||
indices[dim] = linearIndex / stride;
|
||||
linearIndex %= stride;
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
|
||||
int64_t linearIndex = 0;
|
||||
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||
linearIndex += index * stride;
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t dim : shape)
|
||||
numElements *= dim;
|
||||
return numElements;
|
||||
}
|
||||
|
||||
bool hasByteSizedElementType(mlir::Type elementType) {
|
||||
if (mlir::isa<mlir::IndexType>(elementType))
|
||||
return true;
|
||||
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
|
||||
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||
if (mlir::isa<mlir::IndexType>(elementType))
|
||||
return mlir::IndexType::kInternalStorageBitWidth / 8;
|
||||
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||
return static_cast<size_t>(intType.getWidth() / 8);
|
||||
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||
return static_cast<size_t>(floatType.getWidth() / 8);
|
||||
llvm_unreachable("expected byte-sized integer, float, or index element type");
|
||||
}
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
|
||||
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
|
||||
}
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides) {
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstNonZeroOffset = std::find_if(
|
||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||
return offset != 0;
|
||||
});
|
||||
|
||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||
if (size > dimension - offset)
|
||||
return false;
|
||||
++firstNonZeroOffset;
|
||||
|
||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != sizesAndShape.end()) {
|
||||
++firstDifferentSize;
|
||||
|
||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||
auto [size, _dimension] = sizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides) {
|
||||
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|
||||
|| sourceShape.size() != staticStrides.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto reversedTriples =
|
||||
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
|
||||
|
||||
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
|
||||
auto [_sourceDim, offset, _size] = triple;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
|
||||
return true;
|
||||
});
|
||||
|
||||
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
|
||||
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
|
||||
if (size > sourceDim - staticOffset)
|
||||
return false;
|
||||
}
|
||||
|
||||
++firstNonZeroOrDynamicOffset;
|
||||
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
|
||||
if (std::get<2>(*it) != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
|
||||
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
|
||||
auto [sourceDim, size] = pair;
|
||||
return size != sourceDim;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != reversedSizes.end()) {
|
||||
++firstDifferentSize;
|
||||
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
|
||||
if (std::get<1>(*it) != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasByteSizedElementType(mlir::Type elementType);
|
||||
|
||||
size_t getElementTypeSizeInBytes(mlir::Type elementType);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,106 @@
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value stripMemRefCasts(Value value) {
|
||||
while (auto castOp = value.getDefiningOp<memref::CastOp>())
|
||||
value = castOp.getSource();
|
||||
return value;
|
||||
}
|
||||
|
||||
Value stripMemRefViewOps(Value value) {
|
||||
while (true) {
|
||||
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
Value stripMemRefAddressingOps(Value value) {
|
||||
while (true) {
|
||||
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
Value strippedValue = stripMemRefViewOps(value);
|
||||
if (strippedValue == value)
|
||||
return value;
|
||||
value = strippedValue;
|
||||
}
|
||||
}
|
||||
|
||||
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
|
||||
}
|
||||
|
||||
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
|
||||
value = stripMemRefViewOps(value);
|
||||
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
|
||||
if (!subviewOp)
|
||||
return failure();
|
||||
|
||||
auto source = stripMemRefCasts(subviewOp.getSource());
|
||||
auto sourceType = dyn_cast<MemRefType>(source.getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
StaticSubviewInfo info;
|
||||
info.source = source;
|
||||
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
|
||||
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
|
||||
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto staticSize = getConstantIntValue(size);
|
||||
if (!staticSize)
|
||||
return failure();
|
||||
info.sizes.push_back(*staticSize);
|
||||
}
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto staticStride = getConstantIntValue(stride);
|
||||
if (!staticStride)
|
||||
return failure();
|
||||
info.strides.push_back(*staticStride);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
|
||||
SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(info.offsets.size());
|
||||
for (OpFoldResult offset : info.offsets) {
|
||||
auto staticOffset = getConstantIntValue(offset);
|
||||
if (!staticOffset)
|
||||
return failure();
|
||||
staticOffsets.push_back(*staticOffset);
|
||||
}
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
bool isMemRefBaseAddressableValue(Value value) {
|
||||
value = stripMemRefAddressingOps(value);
|
||||
if (isa<BlockArgument>(value))
|
||||
return true;
|
||||
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct StaticSubviewInfo {
|
||||
mlir::Value source;
|
||||
llvm::SmallVector<int64_t> sourceShape;
|
||||
llvm::SmallVector<mlir::OpFoldResult> offsets;
|
||||
llvm::SmallVector<int64_t> sizes;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
};
|
||||
|
||||
mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefAddressingOps(mlir::Value value);
|
||||
|
||||
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
bool isMemRefBaseAddressableValue(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,316 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||
|
||||
void markWeightAlways(mlir::Operation* op) {
|
||||
assert(op && "expected valid op");
|
||||
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
CompiledIndexExpr makeConstantExpr(int64_t constant) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constant;
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
|
||||
}
|
||||
|
||||
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = kind;
|
||||
expr.operands = {std::move(lhs), std::move(rhs)};
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
|
||||
}
|
||||
|
||||
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
|
||||
}
|
||||
|
||||
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||
}
|
||||
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||
return mlir::failure();
|
||||
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
|
||||
return mlir::failure();
|
||||
return strides;
|
||||
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeight() == *weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg || *weightArg != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
||||
mlir::Operation* user = use.getOwner();
|
||||
unsigned operandIndex = use.getOperandNumber();
|
||||
|
||||
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
|
||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||
llvm::SmallPtrSet<mlir::Value, 8> visited;
|
||||
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
|
||||
if (!visited.insert(currentValue).second)
|
||||
return true;
|
||||
if (currentValue.use_empty())
|
||||
return false;
|
||||
|
||||
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
|
||||
if (isSpatialMvmVmmWeightUse(use))
|
||||
return true;
|
||||
|
||||
mlir::Operation* user = use.getOwner();
|
||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
|
||||
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
|
||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
|
||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||
if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
|
||||
return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
|
||||
|
||||
return false;
|
||||
});
|
||||
};
|
||||
|
||||
return walkUses(value, walkUses);
|
||||
}
|
||||
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreBatchOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
||||
weight = stripMemRefAddressingOps(weight);
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
|
||||
llvm::SmallVector<mlir::Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
|
||||
while (true) {
|
||||
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
|
||||
current = directAlias;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto defOp = current.getDefiningOp()) {
|
||||
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
|
||||
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return mlir::failure();
|
||||
|
||||
ResolvedWeightView view;
|
||||
view.globalOp = globalOp;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
|
||||
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
CompiledIndexExpr offsetValue = makeConstantExpr(0);
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
||||
if (!intAttr)
|
||||
return mlir::failure();
|
||||
offsetValue = makeConstantExpr(intAttr.getInt());
|
||||
}
|
||||
else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
|
||||
auto compiledOffset = compileIndexExpr(value);
|
||||
if (failed(compiledOffset))
|
||||
return mlir::failure();
|
||||
offsetValue = *compiledOffset;
|
||||
}
|
||||
else {
|
||||
return mlir::failure();
|
||||
}
|
||||
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
|
||||
}
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
|
||||
auto resultStrides = getStaticMemRefTypeStrides(resultType);
|
||||
if (failed(resultStrides))
|
||||
return mlir::failure();
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = std::move(*resultStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto resolvedOffset = offsetExpr.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
view.offset = *resolvedOffset;
|
||||
return view;
|
||||
}
|
||||
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
current = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
|
||||
current = loopAlias;
|
||||
continue;
|
||||
}
|
||||
|
||||
auto weightIndex = resolveWeightIndex(weightOwner, current);
|
||||
if (!weightIndex)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
current = coreOp.getWeights()[*weightIndex];
|
||||
continue;
|
||||
}
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
current = coreBatchOp.getWeights()[*weightIndex];
|
||||
continue;
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,64 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedWeightView {
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t> shape;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
|
||||
bool operator==(const ResolvedWeightView& other) const {
|
||||
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
|
||||
}
|
||||
};
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||
/// weight across later PIM lowering/codegen stages.
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
|
||||
|
||||
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
|
||||
/// allowing later passes to preserve it as a dedicated weight-like object.
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
|
||||
/// Visits weight operands consumed by Pim core ops/core batches so downstream
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
|
||||
llvm::SmallVector<unsigned, 8> indices;
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
|
||||
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
|
||||
indices.push_back(*weightIndex);
|
||||
});
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,546 +0,0 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() {
|
||||
if (outputBaseName.empty() || outputBaseName == "-")
|
||||
return {};
|
||||
|
||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||
if (lastSlash == std::string::npos)
|
||||
return ".";
|
||||
return outputBaseName.substr(0, lastSlash);
|
||||
}
|
||||
|
||||
void createDirectory(const std::string& directory) {
|
||||
std::error_code errorCode;
|
||||
std::filesystem::create_directories(directory, errorCode);
|
||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||
}
|
||||
|
||||
void dumpModule(ModuleOp moduleOp, const std::string& name) {
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return;
|
||||
|
||||
std::string dialectsDir = outputDir + "/dialects";
|
||||
createDirectory(dialectsDir);
|
||||
|
||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
os << *moduleOp;
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
|
||||
if (!moduleOp)
|
||||
return failure();
|
||||
|
||||
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
|
||||
if (entryPoints.size() > 1) {
|
||||
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
|
||||
return failure();
|
||||
}
|
||||
if (!entryPoints.empty()) {
|
||||
auto entryPointAttr =
|
||||
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
|
||||
if (!entryPointAttr) {
|
||||
entryPoints.front().emitOpError("is missing the entry point function attribute");
|
||||
return failure();
|
||||
}
|
||||
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
|
||||
if (!entryFunc) {
|
||||
entryPoints.front().emitOpError("references an unknown entry function ")
|
||||
<< entryPointAttr.getLeafReference().getValue();
|
||||
return failure();
|
||||
}
|
||||
return entryFunc;
|
||||
}
|
||||
|
||||
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
|
||||
return mainGraphFunc;
|
||||
|
||||
SmallVector<func::FuncOp> nonExternalFuncs;
|
||||
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
|
||||
if (!funcOp.isExternal())
|
||||
nonExternalFuncs.push_back(funcOp);
|
||||
if (nonExternalFuncs.size() == 1)
|
||||
return nonExternalFuncs.front();
|
||||
|
||||
moduleOp.emitError("could not resolve a unique PIM entry function");
|
||||
return failure();
|
||||
}
|
||||
|
||||
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
|
||||
|
||||
void markWeightAlways(Operation* op) {
|
||||
assert(op && "expected valid op");
|
||||
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
|
||||
}
|
||||
|
||||
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
|
||||
if (!moduleOp || !getGlobalOp)
|
||||
return {};
|
||||
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
|
||||
}
|
||||
|
||||
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
|
||||
|
||||
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
|
||||
if (!channelNewOp) {
|
||||
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
|
||||
return failure();
|
||||
}
|
||||
// channelNewOp should have two users: `op` and a
|
||||
// `ChannelSendOp`/`ChannelReceiveOp`
|
||||
auto channelUsers = channelNewOp->getUsers();
|
||||
auto usersIterator = channelUsers.begin();
|
||||
auto firstUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator == channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"only one found.");
|
||||
channelNewOp->dump();
|
||||
op->dump();
|
||||
channelNewOp->getParentOp()->dump();
|
||||
return failure();
|
||||
}
|
||||
auto secondUser = *usersIterator;
|
||||
usersIterator++;
|
||||
if (usersIterator != channelUsers.end()) {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"more than two found.");
|
||||
return failure();
|
||||
}
|
||||
Operation* notOpUser;
|
||||
if (firstUser == op) {
|
||||
notOpUser = secondUser;
|
||||
}
|
||||
else if (secondUser == op) {
|
||||
notOpUser = firstUser;
|
||||
}
|
||||
else {
|
||||
op->emitError("Operand generated by ChannelNewOp must have two users, "
|
||||
"and one of them must be me, but"
|
||||
"none of them is actually me.");
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (opIsReceive) {
|
||||
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelSendOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
else {
|
||||
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
|
||||
op->emitError("Operand generated by ChannelNewOp has two user, one is "
|
||||
"me, the other is not a ChannelReceiveOp.");
|
||||
return failure();
|
||||
}
|
||||
return notOpUser;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
|
||||
SmallVector<int64_t> indices(shape.size(), 0);
|
||||
for (auto [dim, stride] : llvm::enumerate(strides)) {
|
||||
indices[dim] = linearIndex / stride;
|
||||
linearIndex %= stride;
|
||||
}
|
||||
return indices;
|
||||
}
|
||||
|
||||
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
|
||||
int64_t linearIndex = 0;
|
||||
for (auto [index, stride] : llvm::zip_equal(indices, strides))
|
||||
linearIndex += index * stride;
|
||||
return linearIndex;
|
||||
}
|
||||
|
||||
int64_t getNumElements(ArrayRef<int64_t> shape) {
|
||||
int64_t numElements = 1;
|
||||
for (int64_t dim : shape)
|
||||
numElements *= dim;
|
||||
return numElements;
|
||||
}
|
||||
|
||||
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
|
||||
ArrayRef<int64_t> offsets,
|
||||
ArrayRef<int64_t> sizes,
|
||||
ArrayRef<int64_t> strides) {
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
|
||||
llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstNonZeroOffset = std::find_if(
|
||||
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
|
||||
return offset != 0;
|
||||
});
|
||||
|
||||
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
|
||||
auto [offset, size, dimension] = *firstNonZeroOffset;
|
||||
if (size > dimension - offset)
|
||||
return false;
|
||||
++firstNonZeroOffset;
|
||||
|
||||
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
|
||||
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
|
||||
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
|
||||
|
||||
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
|
||||
auto [size, dimension] = sizeAndShape;
|
||||
return size != dimension;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != sizesAndShape.end()) {
|
||||
++firstDifferentSize;
|
||||
|
||||
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
|
||||
auto [size, _dimension] = sizeAndShape;
|
||||
return size != 1;
|
||||
}))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
|
||||
if (!knowledge)
|
||||
return value;
|
||||
|
||||
auto iter = knowledge->aliases.find(value);
|
||||
while (iter != knowledge->aliases.end()) {
|
||||
value = iter->second;
|
||||
iter = knowledge->aliases.find(value);
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
|
||||
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
|
||||
// and when propagating yielded values across iterations during static unrolling.
|
||||
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(value))
|
||||
return value;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return value;
|
||||
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
if (auto result = dyn_cast<OpResult>(value))
|
||||
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
|
||||
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
|
||||
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
if (knowledge) {
|
||||
auto iter = knowledge->indexValues.find(value);
|
||||
if (iter != knowledge->indexValues.end())
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
|
||||
if (constantOp) {
|
||||
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
|
||||
return integerAttr.getInt();
|
||||
}
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
|
||||
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
|
||||
|
||||
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return failure();
|
||||
return *lhs + *rhs;
|
||||
}
|
||||
|
||||
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return failure();
|
||||
return *lhs - *rhs;
|
||||
}
|
||||
|
||||
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return failure();
|
||||
return *lhs * *rhs;
|
||||
}
|
||||
|
||||
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
|
||||
if (auto attr = dyn_cast<Attribute>(ofr)) {
|
||||
auto integerAttr = dyn_cast<IntegerAttr>(attr);
|
||||
if (!integerAttr)
|
||||
return failure();
|
||||
return integerAttr.getInt();
|
||||
}
|
||||
|
||||
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
|
||||
}
|
||||
|
||||
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
int64_t byteOffset = 0;
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
while (true) {
|
||||
if (isa<BlockArgument>(value))
|
||||
return ResolvedContiguousAddress {value, byteOffset};
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
|
||||
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return failure();
|
||||
value = resolveAlias(tiedOperand->get(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
|
||||
auto result = dyn_cast<OpResult>(value);
|
||||
if (!result)
|
||||
return failure();
|
||||
|
||||
// Trace the loop carry back to its underlying memref, then if that memref is the
|
||||
// loop's own iter-arg we know the base comes from the corresponding init arg
|
||||
// (every iteration yields the same backing memory in the DPS sense).
|
||||
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
|
||||
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> offsets;
|
||||
SmallVector<int64_t> sizes;
|
||||
SmallVector<int64_t> strides;
|
||||
offsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
sizes.reserve(subviewOp.getMixedSizes().size());
|
||||
strides.reserve(subviewOp.getMixedStrides().size());
|
||||
|
||||
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
|
||||
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return failure();
|
||||
offsets.push_back(*resolvedOffset);
|
||||
}
|
||||
|
||||
for (OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto resolvedSize = resolveOpFoldResult(size, knowledge);
|
||||
if (failed(resolvedSize))
|
||||
return failure();
|
||||
sizes.push_back(*resolvedSize);
|
||||
}
|
||||
|
||||
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
|
||||
if (failed(resolvedStride))
|
||||
return failure();
|
||||
strides.push_back(*resolvedStride);
|
||||
}
|
||||
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
|
||||
value = resolveAlias(castOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
|
||||
value = resolveAlias(collapseOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
|
||||
value = resolveAlias(expandOp.getSrc(), knowledge);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
|
||||
return ResolvedContiguousAddress {value, byteOffset};
|
||||
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||
|
||||
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveIndexValueImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
|
||||
return resolveContiguousAddressImpl(value, nullptr);
|
||||
}
|
||||
|
||||
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveContiguousAddressImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
bool isCoreStaticAddressOp(Operation* op) {
|
||||
return isa<arith::ConstantOp,
|
||||
arith::AddIOp,
|
||||
arith::SubIOp,
|
||||
arith::MulIOp,
|
||||
arith::DivUIOp,
|
||||
arith::RemUIOp,
|
||||
arith::IndexCastOp,
|
||||
memref::AllocOp,
|
||||
memref::SubViewOp,
|
||||
memref::CastOp,
|
||||
memref::CollapseShapeOp,
|
||||
memref::ExpandShapeOp>(op);
|
||||
}
|
||||
|
||||
LogicalResult walkPimCoreBlock(Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
|
||||
bool hasFailure = false;
|
||||
for (Operation& op : block) {
|
||||
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
|
||||
Block& loopBody = forOp.getRegion().front();
|
||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
|
||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
|
||||
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
|
||||
StaticValueKnowledge loopKnowledge = knowledge;
|
||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
|
||||
loopKnowledge.aliases[iterArg] = iterValue;
|
||||
|
||||
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
|
||||
hasFailure = true;
|
||||
|
||||
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
|
||||
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
|
||||
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(callback(op, knowledge)))
|
||||
hasFailure = true;
|
||||
}
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -7,82 +7,23 @@
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedContiguousAddress {
|
||||
mlir::Value base;
|
||||
int64_t byteOffset = 0;
|
||||
};
|
||||
|
||||
struct StaticValueKnowledge {
|
||||
llvm::DenseMap<mlir::Value, int64_t> indexValues;
|
||||
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
|
||||
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
void markWeightAlways(mlir::Operation* op);
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
llvm::FailureOr<mlir::Operation*>
|
||||
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t>
|
||||
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge);
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
|
||||
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
|
||||
/// only contribute to static addressing or index computations (arith integer math,
|
||||
/// memref view ops, memref.alloc, arith.constant).
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op);
|
||||
|
||||
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
|
||||
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
|
||||
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
|
||||
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
|
||||
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
|
||||
mlir::LogicalResult
|
||||
walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
|
||||
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include "CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
namespace {
|
||||
|
||||
static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) {
|
||||
llvm::errs() << "PIM " << fieldName << " " << message << "\n";
|
||||
}
|
||||
|
||||
template <typename To, typename From>
|
||||
static FailureOr<To> checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCastAtLocation requires integral types");
|
||||
|
||||
using ToLimits = std::numeric_limits<To>;
|
||||
|
||||
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
|
||||
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_signed_v<From>) {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::make_unsigned_t<To>;
|
||||
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
else {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
|
||||
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<To>(value);
|
||||
}
|
||||
|
||||
template <typename UInt>
|
||||
FailureOr<UInt> checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>,
|
||||
"checkedMulAtLocation requires unsigned integral types");
|
||||
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "multiplication overflow");
|
||||
return failure();
|
||||
}
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) {
|
||||
assert(anchor && "expected arithmetic diagnostics to have an anchor op");
|
||||
return anchor->emitOpError() << fieldName << " " << message;
|
||||
}
|
||||
|
||||
InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) {
|
||||
return emitError(loc) << "PIM " << fieldName << " " << message;
|
||||
}
|
||||
|
||||
FailureOr<int32_t> checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<int32_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<int32_t> checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<int32_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<uint8_t> checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<uint8_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<size_t> checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
|
||||
return checkedCast<size_t>(value, anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr>
|
||||
getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) {
|
||||
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
|
||||
auto checkedValue = checkedI32(value, anchor, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr>
|
||||
getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) {
|
||||
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
|
||||
auto checkedValue = checkedI32(value, anchor, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) {
|
||||
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) {
|
||||
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
|
||||
if (failed(checkedValue))
|
||||
return failure();
|
||||
return builder.getI32IntegerAttr(*checkedValue);
|
||||
}
|
||||
|
||||
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) {
|
||||
assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor");
|
||||
if (!type.hasStaticShape()) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "requires static shaped type");
|
||||
return failure();
|
||||
}
|
||||
if (!hasByteSizedElementType(type.getElementType())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "requires byte-sized element type");
|
||||
return failure();
|
||||
}
|
||||
|
||||
uint64_t elements = 1;
|
||||
for (int64_t dim : type.getShape()) {
|
||||
if (dim < 0) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "requires nonnegative dimensions");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto nextElements = checkedMul(elements, static_cast<uint64_t>(dim), anchor, fieldName);
|
||||
if (failed(nextElements))
|
||||
return failure();
|
||||
elements = *nextElements;
|
||||
}
|
||||
|
||||
return checkedMul(
|
||||
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), anchor, fieldName);
|
||||
}
|
||||
|
||||
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) {
|
||||
if (!type.hasStaticShape()) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "requires static shaped type");
|
||||
return failure();
|
||||
}
|
||||
if (!hasByteSizedElementType(type.getElementType())) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "requires byte-sized element type");
|
||||
return failure();
|
||||
}
|
||||
|
||||
uint64_t elements = 1;
|
||||
for (int64_t dim : type.getShape()) {
|
||||
if (dim < 0) {
|
||||
emitCheckedArithmeticError(loc, fieldName, "requires nonnegative dimensions");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto nextElements = checkedMulAtLocation(elements, static_cast<uint64_t>(dim), loc, fieldName);
|
||||
if (failed(nextElements))
|
||||
return failure();
|
||||
elements = *nextElements;
|
||||
}
|
||||
|
||||
return checkedMulAtLocation(
|
||||
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), loc, fieldName);
|
||||
}
|
||||
|
||||
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) {
|
||||
if (value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max()) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) {
|
||||
if (value > static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<int32_t>(value);
|
||||
}
|
||||
|
||||
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) {
|
||||
if (value > static_cast<uint64_t>(std::numeric_limits<uint8_t>::max())) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<uint8_t>(value);
|
||||
}
|
||||
|
||||
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) {
|
||||
if (value < 0) {
|
||||
emitCrashMessage(fieldName, "is outside representable range");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return static_cast<size_t>(value);
|
||||
}
|
||||
|
||||
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
|
||||
if (rhs > std::numeric_limits<size_t>::max() - lhs) {
|
||||
emitCrashMessage(fieldName, "addition overflow");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
|
||||
if (lhs != 0 && rhs > std::numeric_limits<size_t>::max() / lhs) {
|
||||
emitCrashMessage(fieldName, "multiplication overflow");
|
||||
llvm_unreachable("PIM checked arithmetic failure");
|
||||
}
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,107 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypeInterfaces.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::InFlightDiagnostic
|
||||
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
|
||||
|
||||
mlir::InFlightDiagnostic
|
||||
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
|
||||
|
||||
template <typename To, typename From>
|
||||
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
|
||||
|
||||
using ToLimits = std::numeric_limits<To>;
|
||||
|
||||
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
|
||||
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_signed_v<From>) {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::make_unsigned_t<To>;
|
||||
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
else {
|
||||
using UnsignedFrom = std::make_unsigned_t<From>;
|
||||
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
|
||||
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<To>(value);
|
||||
}
|
||||
|
||||
template <typename UInt>
|
||||
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
|
||||
if (rhs > std::numeric_limits<UInt>::max() - lhs) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
|
||||
return mlir::failure();
|
||||
}
|
||||
return lhs + rhs;
|
||||
}
|
||||
|
||||
template <typename UInt>
|
||||
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
|
||||
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
|
||||
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
|
||||
emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
|
||||
return mlir::failure();
|
||||
}
|
||||
return lhs * rhs;
|
||||
}
|
||||
|
||||
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<mlir::IntegerAttr>
|
||||
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<uint64_t>
|
||||
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
|
||||
|
||||
mlir::FailureOr<uint64_t>
|
||||
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
|
||||
|
||||
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
|
||||
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
|
||||
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
|
||||
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
|
||||
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
|
||||
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,27 @@
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return;
|
||||
|
||||
std::string dialectsDir = outputDir + "/dialects";
|
||||
createDirectory(dialectsDir);
|
||||
|
||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
mlir::OpPrintingFlags flags;
|
||||
flags.elideLargeElementsAttrs().enableDebugInfo(true, false);
|
||||
moduleOp.print(os, flags);
|
||||
os.flush();
|
||||
file.close();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Emits a MLIR snapshot under the current compiler output
|
||||
/// directory for pass-level debugging.
|
||||
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,41 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
|
||||
return op->emitOpError() << "requires statically shaped " << valueDescription;
|
||||
}
|
||||
|
||||
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||
llvm::StringRef valueDescription,
|
||||
int64_t actualRank,
|
||||
llvm::ArrayRef<int64_t> supportedRanks) {
|
||||
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
|
||||
if (supportedRanks.empty())
|
||||
return diag;
|
||||
|
||||
diag << "; supported rank";
|
||||
if (supportedRanks.size() != 1)
|
||||
diag << 's';
|
||||
diag << ' ';
|
||||
|
||||
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
|
||||
return diag;
|
||||
}
|
||||
|
||||
mlir::InFlightDiagnostic
|
||||
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
|
||||
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
|
||||
}
|
||||
|
||||
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
|
||||
llvm::StringRef action,
|
||||
llvm::StringRef path,
|
||||
const std::error_code& errorCode) {
|
||||
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,64 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <system_error>
|
||||
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
|
||||
: maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
numFailures++;
|
||||
if (numFailures <= maxReportedFailures)
|
||||
emit(op);
|
||||
}
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||
}
|
||||
|
||||
void noteFailures(int64_t count) { numFailures += count; }
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
private:
|
||||
int64_t maxReportedFailures;
|
||||
int64_t numFailures = 0;
|
||||
};
|
||||
|
||||
/// Emits a consistent diagnostic for target paths that require static shapes.
|
||||
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
|
||||
|
||||
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
|
||||
/// accepted by the current lowering/codegen path.
|
||||
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
|
||||
llvm::StringRef valueDescription,
|
||||
int64_t actualRank,
|
||||
llvm::ArrayRef<int64_t> supportedRanks);
|
||||
|
||||
/// Emits a consistent diagnostic for missing symbol/global references.
|
||||
mlir::InFlightDiagnostic
|
||||
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
|
||||
|
||||
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
|
||||
/// the relevant IR location.
|
||||
mlir::LogicalResult
|
||||
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
|
||||
|
||||
template <typename T>
|
||||
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
|
||||
return mlir::success(succeeded(value));
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim
|
||||
@@ -0,0 +1,24 @@
|
||||
#include <filesystem>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::string getOutputDir() {
|
||||
if (outputBaseName.empty() || outputBaseName == "-")
|
||||
return {};
|
||||
|
||||
size_t lastSlash = outputBaseName.find_last_of('/');
|
||||
if (lastSlash == std::string::npos)
|
||||
return ".";
|
||||
return outputBaseName.substr(0, lastSlash);
|
||||
}
|
||||
|
||||
void createDirectory(const std::string& directory) {
|
||||
std::error_code errorCode;
|
||||
std::filesystem::create_directories(directory, errorCode);
|
||||
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns the directory that should hold PIM artifacts/debug dumps for the
|
||||
/// current compiler invocation.
|
||||
std::string getOutputDir();
|
||||
|
||||
void createDirectory(const std::string& directory);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,64 @@
|
||||
#include "llvm/Support/Format.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
|
||||
std::string outputDir = getOutputDir();
|
||||
if (outputDir.empty())
|
||||
return {};
|
||||
|
||||
std::string reportsDir = outputDir + "/reports";
|
||||
createDirectory(reportsDir);
|
||||
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out);
|
||||
}
|
||||
|
||||
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
|
||||
|
||||
std::string formatReportMemory(uint64_t bytes) {
|
||||
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
|
||||
int i = 0;
|
||||
double size = static_cast<double>(bytes);
|
||||
while (size >= 1024 && i < 6) {
|
||||
size /= 1024;
|
||||
i++;
|
||||
}
|
||||
|
||||
std::string out;
|
||||
llvm::raw_string_ostream rss(out);
|
||||
rss << llvm::format("%.2f ", size) << units[i];
|
||||
return rss.str();
|
||||
}
|
||||
|
||||
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
|
||||
for (const ReportField& field : fields)
|
||||
os << "\t" << field.label << ": " << field.value << "\n";
|
||||
}
|
||||
|
||||
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields) {
|
||||
os << "\t" << title << ":\n";
|
||||
for (const ReportField& field : fields)
|
||||
os << "\t " << field.label << ": " << field.value << "\n";
|
||||
}
|
||||
|
||||
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
|
||||
os << "Totals:\n";
|
||||
for (const ReportField& field : fields)
|
||||
os << "\t" << field.label << ": " << field.value << "\n";
|
||||
}
|
||||
|
||||
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
|
||||
llvm::ArrayRef<ReportField> perCoreFields,
|
||||
llvm::ArrayRef<ReportField> totalFields) {
|
||||
printReportFieldBlock(os, "Per core", perCoreFields);
|
||||
printReportFieldBlock(os, "Total", totalFields);
|
||||
}
|
||||
|
||||
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry) {
|
||||
if (hasNextEntry)
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <string>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
std::fstream openReportFile(const std::string& name);
|
||||
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
|
||||
std::string formatReportMemory(uint64_t bytes);
|
||||
|
||||
struct ReportField {
|
||||
std::string label;
|
||||
std::string value;
|
||||
};
|
||||
|
||||
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
|
||||
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields);
|
||||
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
|
||||
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
|
||||
llvm::ArrayRef<ReportField> perCoreFields,
|
||||
llvm::ArrayRef<ReportField> totalFields);
|
||||
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry);
|
||||
|
||||
template <typename EntryTy>
|
||||
int32_t getFirstReportCoreId(const EntryTy& entry) {
|
||||
if (entry.coreIds.empty())
|
||||
return std::numeric_limits<int32_t>::max();
|
||||
return entry.coreIds.front();
|
||||
}
|
||||
|
||||
template <typename EntryRange>
|
||||
void sortReportEntriesByFirstCore(EntryRange& entries) {
|
||||
llvm::stable_sort(entries, [](const auto& lhs, const auto& rhs) {
|
||||
int32_t lhsFirstCore = getFirstReportCoreId(lhs);
|
||||
int32_t rhsFirstCore = getFirstReportCoreId(rhs);
|
||||
if (lhsFirstCore != rhsFirstCore)
|
||||
return lhsFirstCore < rhsFirstCore;
|
||||
return lhs.id < rhs.id;
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
|
||||
|
||||
add_pim_library(OMPimCompilerUtils
|
||||
PimCompilerUtils.cpp
|
||||
PimArtifactWriter.cpp
|
||||
PimCodeGen.cpp
|
||||
PimMemoryLiveness.cpp
|
||||
PimWeightEmitter.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
@@ -26,6 +29,9 @@ add_pim_library(OMPimCompilerUtils
|
||||
OMPimCompilerOptions
|
||||
OMPimCommon
|
||||
OMPimBufferization
|
||||
OMPimMemoryCoalescing
|
||||
OMPimHostConstantFolding
|
||||
OMPimVerification
|
||||
OMPimPasses
|
||||
OMONNXToSpatial
|
||||
OMSpatialToPim
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
OnnxMlirCompilerErrorCodes
|
||||
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
|
||||
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
|
||||
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
|
||||
|
||||
SmallPtrSet<Operation*, 16> writtenGlobals;
|
||||
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
|
||||
if (hasWeightAlways(getGlobalOp))
|
||||
return;
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp)
|
||||
return;
|
||||
if (!writtenGlobals.insert(globalOp.getOperation()).second)
|
||||
return;
|
||||
auto initialValue = globalOp.getInitialValue();
|
||||
if (!initialValue)
|
||||
return;
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
|
||||
if (!denseAttr)
|
||||
return;
|
||||
|
||||
MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
|
||||
ArrayRef<char> rawData = denseAttr.getRawData();
|
||||
char* dst = memoryBuffer.data() + memEntry.address;
|
||||
|
||||
if (denseAttr.isSplat()) {
|
||||
size_t elementSize = rawData.size();
|
||||
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
|
||||
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
|
||||
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
|
||||
}
|
||||
else {
|
||||
assert(rawData.size() == memEntry.size && "Data size mismatch");
|
||||
std::memcpy(dst, rawData.data(), rawData.size());
|
||||
}
|
||||
});
|
||||
|
||||
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
|
||||
memoryFileStream.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
size_t maxCoreId,
|
||||
json::Object xbarsPerArrayGroup,
|
||||
StringRef outputDirPath) {
|
||||
json::Object configJson;
|
||||
|
||||
configJson["core_cnt"] = maxCoreId + 1;
|
||||
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
|
||||
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
|
||||
|
||||
json::Array inputsAddresses;
|
||||
for (BlockArgument input : funcOp.getArguments())
|
||||
inputsAddresses.push_back(memory.getValueAddress(input));
|
||||
configJson["inputs_addresses"] = std::move(inputsAddresses);
|
||||
|
||||
json::Array outputsAddresses;
|
||||
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
|
||||
for (mlir::Value output : returnOp.getOperands())
|
||||
outputsAddresses.push_back(memory.getValueAddress(output));
|
||||
configJson["outputs_addresses"] = std::move(outputsAddresses);
|
||||
|
||||
auto configPath = (outputDirPath + "/config.json").str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream jsonOS(configPath, errorCode);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening config file: " << errorCode.message() << '\n';
|
||||
return InvalidOutputFileAccess;
|
||||
}
|
||||
jsonOS << json::Value(std::move(configJson)) << '\n';
|
||||
jsonOS.close();
|
||||
return CompilerSuccess;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
class PimAcceleratorMemory;
|
||||
|
||||
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
|
||||
mlir::func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
llvm::StringRef outputDirPath);
|
||||
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
|
||||
PimAcceleratorMemory& memory,
|
||||
size_t maxCoreId,
|
||||
llvm::json::Object xbarsPerArrayGroup,
|
||||
llvm::StringRef outputDirPath);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,369 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Endian.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
|
||||
namespace onnx_mlir::pim_binary {
|
||||
|
||||
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};
|
||||
inline constexpr uint32_t kVersion = 1;
|
||||
inline constexpr uint64_t kCountOffset = 8;
|
||||
inline constexpr size_t kHeaderSize = 12;
|
||||
inline constexpr size_t kRecordSize = 20;
|
||||
|
||||
enum class Opcode : uint32_t {
|
||||
nop = 0,
|
||||
sldi = 1,
|
||||
sld = 2,
|
||||
sadd = 3,
|
||||
ssub = 4,
|
||||
smul = 5,
|
||||
saddi = 6,
|
||||
smuli = 7,
|
||||
setbw = 8,
|
||||
mvmul = 9,
|
||||
vvadd = 10,
|
||||
vvsub = 11,
|
||||
vvmul = 12,
|
||||
vvdmul = 13,
|
||||
vvmax = 14,
|
||||
vvsll = 15,
|
||||
vvsra = 16,
|
||||
vavg = 17,
|
||||
vrelu = 18,
|
||||
vtanh = 19,
|
||||
vsigm = 20,
|
||||
vsoftmax = 21,
|
||||
vmv = 22,
|
||||
vrsu = 23,
|
||||
vrsl = 24,
|
||||
ld = 25,
|
||||
st = 26,
|
||||
lldi = 27,
|
||||
lmv = 28,
|
||||
send = 29,
|
||||
recv = 30,
|
||||
wait = 31,
|
||||
sync = 32,
|
||||
};
|
||||
|
||||
struct InstructionRecord {
|
||||
Opcode opcode = Opcode::nop;
|
||||
uint8_t rd = 0;
|
||||
uint8_t r1 = 0;
|
||||
int32_t r2OrImm = 0;
|
||||
int32_t generic1 = 0;
|
||||
int32_t generic2 = 0;
|
||||
int32_t generic3 = 0;
|
||||
uint8_t flags = 0;
|
||||
};
|
||||
|
||||
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
|
||||
std::array<char, sizeof(uint32_t)> bytes;
|
||||
llvm::support::endian::write32le(bytes.data(), value);
|
||||
os.write(bytes.data(), bytes.size());
|
||||
}
|
||||
|
||||
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
|
||||
|
||||
inline void writeHeader(llvm::raw_ostream& os) {
|
||||
os.write(kMagic, sizeof(kMagic));
|
||||
writeUint32LE(os, kVersion);
|
||||
writeUint32LE(os, 0);
|
||||
}
|
||||
|
||||
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
|
||||
std::array<char, sizeof(uint32_t)> bytes;
|
||||
llvm::support::endian::write32le(bytes.data(), instructionCount);
|
||||
os.pwrite(bytes.data(), bytes.size(), kCountOffset);
|
||||
}
|
||||
|
||||
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
|
||||
os << static_cast<char>(static_cast<uint8_t>(record.opcode));
|
||||
os << static_cast<char>(record.rd);
|
||||
os << static_cast<char>(record.r1);
|
||||
os << static_cast<char>(record.flags);
|
||||
writeInt32LE(os, record.r2OrImm);
|
||||
writeInt32LE(os, record.generic1);
|
||||
writeInt32LE(os, record.generic2);
|
||||
writeInt32LE(os, record.generic3);
|
||||
}
|
||||
|
||||
inline int32_t toI32(int64_t value) { return onnx_mlir::pim::checkedI32OrCrash(value, "binary field"); }
|
||||
|
||||
inline uint8_t toU8(int64_t value) {
|
||||
return onnx_mlir::pim::checkedU8OrCrash(static_cast<uint64_t>(value), "binary field");
|
||||
}
|
||||
|
||||
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
|
||||
if (std::optional<int64_t> value = object.getInteger(key))
|
||||
return toI32(*value);
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
inline Opcode opcodeFromString(llvm::StringRef opName) {
|
||||
if (opName == "nop")
|
||||
return Opcode::nop;
|
||||
if (opName == "sldi")
|
||||
return Opcode::sldi;
|
||||
if (opName == "sld")
|
||||
return Opcode::sld;
|
||||
if (opName == "sadd")
|
||||
return Opcode::sadd;
|
||||
if (opName == "ssub")
|
||||
return Opcode::ssub;
|
||||
if (opName == "smul")
|
||||
return Opcode::smul;
|
||||
if (opName == "saddi")
|
||||
return Opcode::saddi;
|
||||
if (opName == "smuli")
|
||||
return Opcode::smuli;
|
||||
if (opName == "setbw")
|
||||
return Opcode::setbw;
|
||||
if (opName == "mvmul")
|
||||
return Opcode::mvmul;
|
||||
if (opName == "vvadd")
|
||||
return Opcode::vvadd;
|
||||
if (opName == "vvsub")
|
||||
return Opcode::vvsub;
|
||||
if (opName == "vvmul")
|
||||
return Opcode::vvmul;
|
||||
if (opName == "vvdmul")
|
||||
return Opcode::vvdmul;
|
||||
if (opName == "vvmax")
|
||||
return Opcode::vvmax;
|
||||
if (opName == "vvsll")
|
||||
return Opcode::vvsll;
|
||||
if (opName == "vvsra")
|
||||
return Opcode::vvsra;
|
||||
if (opName == "vavg")
|
||||
return Opcode::vavg;
|
||||
if (opName == "vrelu")
|
||||
return Opcode::vrelu;
|
||||
if (opName == "vtanh")
|
||||
return Opcode::vtanh;
|
||||
if (opName == "vsigm")
|
||||
return Opcode::vsigm;
|
||||
if (opName == "vsoftmax")
|
||||
return Opcode::vsoftmax;
|
||||
if (opName == "vmv")
|
||||
return Opcode::vmv;
|
||||
if (opName == "vrsu")
|
||||
return Opcode::vrsu;
|
||||
if (opName == "vrsl")
|
||||
return Opcode::vrsl;
|
||||
if (opName == "ld")
|
||||
return Opcode::ld;
|
||||
if (opName == "st")
|
||||
return Opcode::st;
|
||||
if (opName == "lldi")
|
||||
return Opcode::lldi;
|
||||
if (opName == "lmv")
|
||||
return Opcode::lmv;
|
||||
if (opName == "send")
|
||||
return Opcode::send;
|
||||
if (opName == "recv")
|
||||
return Opcode::recv;
|
||||
if (opName == "wait")
|
||||
return Opcode::wait;
|
||||
if (opName == "sync")
|
||||
return Opcode::sync;
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
|
||||
inline llvm::StringRef opcodeToString(Opcode opcode) {
|
||||
switch (opcode) {
|
||||
case Opcode::nop: return "nop";
|
||||
case Opcode::sldi: return "sldi";
|
||||
case Opcode::sld: return "sld";
|
||||
case Opcode::sadd: return "sadd";
|
||||
case Opcode::ssub: return "ssub";
|
||||
case Opcode::smul: return "smul";
|
||||
case Opcode::saddi: return "saddi";
|
||||
case Opcode::smuli: return "smuli";
|
||||
case Opcode::setbw: return "setbw";
|
||||
case Opcode::mvmul: return "mvmul";
|
||||
case Opcode::vvadd: return "vvadd";
|
||||
case Opcode::vvsub: return "vvsub";
|
||||
case Opcode::vvmul: return "vvmul";
|
||||
case Opcode::vvdmul: return "vvdmul";
|
||||
case Opcode::vvmax: return "vvmax";
|
||||
case Opcode::vvsll: return "vvsll";
|
||||
case Opcode::vvsra: return "vvsra";
|
||||
case Opcode::vavg: return "vavg";
|
||||
case Opcode::vrelu: return "vrelu";
|
||||
case Opcode::vtanh: return "vtanh";
|
||||
case Opcode::vsigm: return "vsigm";
|
||||
case Opcode::vsoftmax: return "vsoftmax";
|
||||
case Opcode::vmv: return "vmv";
|
||||
case Opcode::vrsu: return "vrsu";
|
||||
case Opcode::vrsl: return "vrsl";
|
||||
case Opcode::ld: return "ld";
|
||||
case Opcode::st: return "st";
|
||||
case Opcode::lldi: return "lldi";
|
||||
case Opcode::lmv: return "lmv";
|
||||
case Opcode::send: return "send";
|
||||
case Opcode::recv: return "recv";
|
||||
case Opcode::wait: return "wait";
|
||||
case Opcode::sync: return "sync";
|
||||
}
|
||||
llvm_unreachable("Unsupported PIM binary opcode");
|
||||
}
|
||||
|
||||
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
|
||||
InstructionRecord record;
|
||||
std::optional<llvm::StringRef> opName = instruction.getString("op");
|
||||
assert(opName && "Missing op field in PIM instruction");
|
||||
record.opcode = opcodeFromString(*opName);
|
||||
record.rd = toU8(getOptionalInt(instruction, "rd"));
|
||||
record.r1 = toU8(getOptionalInt(instruction, "rs1"));
|
||||
|
||||
switch (record.opcode) {
|
||||
case Opcode::sldi:
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
|
||||
case Opcode::mvmul:
|
||||
record.r2OrImm = getOptionalInt(instruction, "mbiw");
|
||||
record.generic1 = getOptionalInt(instruction, "relu");
|
||||
record.generic2 = getOptionalInt(instruction, "group");
|
||||
break;
|
||||
case Opcode::setbw:
|
||||
record.generic1 = getOptionalInt(instruction, "ibiw");
|
||||
record.generic2 = getOptionalInt(instruction, "obiw");
|
||||
break;
|
||||
case Opcode::send:
|
||||
case Opcode::recv:
|
||||
record.r2OrImm = getOptionalInt(instruction, "core");
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
break;
|
||||
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
|
||||
}
|
||||
|
||||
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
|
||||
if (auto* offsetValue = instruction.getObject("offset")) {
|
||||
record.generic1 = getOptionalInt(*offsetValue, "offset_select");
|
||||
record.generic2 = getOptionalInt(*offsetValue, "offset_value");
|
||||
}
|
||||
}
|
||||
|
||||
if (instruction.get("len"))
|
||||
record.generic3 = getOptionalInt(instruction, "len");
|
||||
else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv)
|
||||
record.generic3 = getOptionalInt(instruction, "size");
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
|
||||
llvm::json::Object instruction;
|
||||
instruction["op"] = opcodeToString(record.opcode).str();
|
||||
|
||||
auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) {
|
||||
llvm::json::Object offset;
|
||||
offset["offset_select"] = offsetSelect;
|
||||
offset["offset_value"] = offsetValue;
|
||||
instruction["offset"] = std::move(offset);
|
||||
};
|
||||
|
||||
switch (record.opcode) {
|
||||
case Opcode::sldi:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::sld:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
break;
|
||||
case Opcode::sadd:
|
||||
case Opcode::ssub:
|
||||
case Opcode::smul:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["rs2"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::saddi:
|
||||
case Opcode::smuli:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
break;
|
||||
case Opcode::setbw:
|
||||
instruction["ibiw"] = record.generic1;
|
||||
instruction["obiw"] = record.generic2;
|
||||
break;
|
||||
case Opcode::mvmul:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["mbiw"] = record.r2OrImm;
|
||||
instruction["relu"] = record.generic1;
|
||||
instruction["group"] = record.generic2;
|
||||
break;
|
||||
case Opcode::vvadd:
|
||||
case Opcode::vvsub:
|
||||
case Opcode::vvmul:
|
||||
case Opcode::vvdmul:
|
||||
case Opcode::vvmax:
|
||||
case Opcode::vvsll:
|
||||
case Opcode::vvsra:
|
||||
case Opcode::vavg:
|
||||
case Opcode::vmv:
|
||||
case Opcode::vrsu:
|
||||
case Opcode::vrsl:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
instruction["rs2"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::vrelu:
|
||||
case Opcode::vtanh:
|
||||
case Opcode::vsigm:
|
||||
case Opcode::vsoftmax:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::ld:
|
||||
case Opcode::st:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["size"] = record.generic3;
|
||||
break;
|
||||
case Opcode::lldi:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["imm"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::lmv:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["rs1"] = static_cast<int64_t>(record.r1);
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["len"] = record.generic3;
|
||||
break;
|
||||
case Opcode::send:
|
||||
case Opcode::recv:
|
||||
instruction["rd"] = static_cast<int64_t>(record.rd);
|
||||
instruction["core"] = record.r2OrImm;
|
||||
addOffset(record.generic1, record.generic2);
|
||||
instruction["size"] = record.generic3;
|
||||
break;
|
||||
case Opcode::wait:
|
||||
case Opcode::sync:
|
||||
case Opcode::nop: break;
|
||||
}
|
||||
|
||||
return instruction;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir::pim_binary
|
||||
+1295
-541
File diff suppressed because it is too large
Load Diff
+160
-20
@@ -1,10 +1,24 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
|
||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -14,57 +28,155 @@ struct MemEntry {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
struct PhysicalSlotInfo {
|
||||
size_t id = 0;
|
||||
size_t address = 0;
|
||||
size_t size = 0;
|
||||
};
|
||||
|
||||
struct MemoryPlanArtifacts {
|
||||
std::string textReport;
|
||||
};
|
||||
|
||||
struct MemoryValueKey {
|
||||
mlir::Value value;
|
||||
std::optional<unsigned> lane;
|
||||
|
||||
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
|
||||
};
|
||||
|
||||
struct MemoryReportRow {
|
||||
uint64_t numAlloca = 0;
|
||||
uint64_t sizeAlloca = 0;
|
||||
uint64_t numGlobal = 0;
|
||||
uint64_t sizeGlobal = 0;
|
||||
|
||||
bool operator==(const MemoryReportRow& other) const {
|
||||
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
|
||||
&& sizeGlobal == other.sizeGlobal;
|
||||
}
|
||||
};
|
||||
|
||||
enum class MemoryReportKind {
|
||||
None,
|
||||
Alloca,
|
||||
Global,
|
||||
Input
|
||||
};
|
||||
|
||||
struct PendingMemEntry {
|
||||
MemEntry memEntry;
|
||||
MemoryValueKey key;
|
||||
MemoryReportKind reportKind = MemoryReportKind::None;
|
||||
};
|
||||
|
||||
struct MemoryReportEntry {
|
||||
enum class Kind {
|
||||
Core,
|
||||
Batch
|
||||
};
|
||||
|
||||
Kind kind = Kind::Core;
|
||||
uint64_t id = 0;
|
||||
llvm::SmallVector<int32_t, 8> coreIds;
|
||||
MemoryReportRow row;
|
||||
uint64_t totalAllocaCount = 0;
|
||||
uint64_t totalAllocaBytes = 0;
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
llvm::SmallVector<PendingMemEntry, 32> memEntries;
|
||||
llvm::SmallVector<PhysicalSlotInfo, 32> localPhysicalSlots;
|
||||
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
|
||||
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
|
||||
MemoryReportRow reportRow;
|
||||
MemoryPlanArtifacts livenessArtifacts;
|
||||
|
||||
size_t maxSize = 0; // 0 for unbounded memory
|
||||
size_t startAddress = 0;
|
||||
size_t minAlignment = 4;
|
||||
size_t firstAvailableAddress = 0;
|
||||
size_t nextPhysicalSlotId = 0;
|
||||
|
||||
MemEntry* gatherMemEntry(mlir::Value value);
|
||||
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
|
||||
void allocateGatheredMemory();
|
||||
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind);
|
||||
PhysicalSlotInfo allocatePhysicalSlot(size_t slotSize, const MemoryValueKey& key);
|
||||
|
||||
public:
|
||||
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
|
||||
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
|
||||
: globalMemEntriesMap(globalMemEntriesMap) {}
|
||||
|
||||
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||
void allocateCore(mlir::Operation* op);
|
||||
void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
|
||||
MemoryReportRow getReportRow() const;
|
||||
const MemoryPlanArtifacts& getLivenessArtifacts() const { return livenessArtifacts; }
|
||||
void remove(mlir::Value val);
|
||||
|
||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||
MemEntry getMemEntry(mlir::Value value) const;
|
||||
MemEntry getMemEntry(const MemoryValueKey& key) const;
|
||||
};
|
||||
|
||||
class PimAcceleratorMemory {
|
||||
public:
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
|
||||
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
|
||||
PimMemory hostMem;
|
||||
|
||||
private:
|
||||
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
|
||||
std::fstream fileReport;
|
||||
std::optional<MemoryReportRow> hostReportRow;
|
||||
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||
uint64_t totalWeightBytes = 0;
|
||||
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
|
||||
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
|
||||
|
||||
public:
|
||||
PimAcceleratorMemory()
|
||||
: hostMem(memEntriesMap) {}
|
||||
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
|
||||
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
|
||||
: memEntriesMap(initialMemEntries),
|
||||
hostMem(memEntriesMap),
|
||||
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
|
||||
|
||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||
|
||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
size_t getValueAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge = {},
|
||||
std::optional<unsigned> lane = std::nullopt) const;
|
||||
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
void reportHost();
|
||||
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||
void recordBatchReport(uint64_t batchId,
|
||||
llvm::ArrayRef<int32_t> coreIds,
|
||||
const MemoryReportRow& perCoreRow,
|
||||
uint64_t totalAllocaCount,
|
||||
uint64_t totalAllocaBytes);
|
||||
void setTotalWeightBytes(uint64_t bytes) { totalWeightBytes = bytes; }
|
||||
void flushReport();
|
||||
void clean(mlir::Operation* op);
|
||||
};
|
||||
|
||||
struct CoreEmissionJob {
|
||||
mlir::Operation* coreLikeOp = nullptr;
|
||||
size_t originalCoreId = 0;
|
||||
size_t emittedCoreId = 0;
|
||||
llvm::SmallVector<unsigned, 4> lanes;
|
||||
std::optional<uint64_t> batchReportId;
|
||||
};
|
||||
|
||||
class PimCodeGen {
|
||||
PimAcceleratorMemory& memory;
|
||||
llvm::raw_fd_ostream& coreFileStream;
|
||||
llvm::raw_fd_ostream& coreBinaryStream;
|
||||
llvm::raw_fd_ostream* coreJsonStream;
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||
std::optional<unsigned> batchLane;
|
||||
mutable uint32_t emittedInstructionCount = 0;
|
||||
|
||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getValueAddress(value, knowledge);
|
||||
return memory.getValueAddress(value, knowledge, batchLane);
|
||||
}
|
||||
size_t remapCoreId(size_t coreId) const;
|
||||
|
||||
static llvm::json::Object createEmptyOffset();
|
||||
void emitInstruction(llvm::json::Object instruction) const;
|
||||
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
|
||||
|
||||
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
|
||||
void setupRd(size_t rdAddress, size_t rdOffset) const;
|
||||
@@ -83,8 +195,17 @@ class PimCodeGen {
|
||||
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
|
||||
|
||||
public:
|
||||
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
|
||||
: memory(memory), coreFileStream(coreJson) {}
|
||||
PimCodeGen(PimAcceleratorMemory& memory,
|
||||
llvm::raw_fd_ostream& coreBinary,
|
||||
llvm::raw_fd_ostream* coreJson,
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
|
||||
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||
|
||||
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
|
||||
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
|
||||
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getIndexValue(value, knowledge);
|
||||
}
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||
@@ -92,6 +213,7 @@ public:
|
||||
|
||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
template <typename MVMTy>
|
||||
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
|
||||
@@ -106,9 +228,27 @@ public:
|
||||
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
template <>
|
||||
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
|
||||
static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
|
||||
|
||||
static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
|
||||
|
||||
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
|
||||
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
|
||||
}
|
||||
|
||||
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
@@ -1,16 +1,5 @@
|
||||
/*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
//===------------------------- PimCompilerOptions.cpp --------------------===//
|
||||
//
|
||||
// Copyright 2022 The IBM Research Authors.
|
||||
//
|
||||
// =============================================================================
|
||||
//
|
||||
// Compiler Options for PIM
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
@@ -26,36 +15,67 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMergeSchedulerType>
|
||||
pimMergeScheduler("pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
|
||||
"pim-memory-report",
|
||||
llvm::cl::desc("Emit a human-readable PIM memory planning report"),
|
||||
llvm::cl::values(clEnumValN(PimMemoryReportNone, "none", "Do not emit any PIM memory planning report")),
|
||||
llvm::cl::values(
|
||||
clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders")),
|
||||
llvm::cl::values(clEnumValN(PimMemoryReportFull, "full", "Emit the full detailed PIM memory planning report")),
|
||||
llvm::cl::init(PimMemoryReportNone),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimOnlyCodegen("pim-only-codegen",
|
||||
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
pimDisableMemoryCoalescing("pim-disable-memory-coalescing",
|
||||
llvm::cl::desc("Skip the PIM memory coalescing pass (developer diagnostic option)"),
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
|
||||
llvm::cl::desc("Use experimental implementation for convolution"),
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
|
||||
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
|
||||
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
|
||||
llvm::cl::init(false),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
|
||||
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
|
||||
|
||||
llvm::cl::opt<size_t>
|
||||
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
|
||||
|
||||
llvm::cl::opt<long> coresCount("core-count",
|
||||
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
|
||||
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||
llvm::cl::init(-1));
|
||||
|
||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||
"dcp-critical-window-size",
|
||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||
"Use 0 to run the legacy full-graph DCP analysis."),
|
||||
llvm::cl::init(1024));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
ignoreConcatError("ignore-concat-error",
|
||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||
llvm::cl::init(false));
|
||||
|
||||
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
|
||||
|
||||
void verifyExplicitPimCoreCount() {
|
||||
if (!hasExplicitPimCoreCount())
|
||||
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
|
||||
if (coresCount.getValue() <= 0)
|
||||
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,16 +20,32 @@ typedef enum {
|
||||
EmitPimCodegen = 3
|
||||
} PimEmissionTargetType;
|
||||
|
||||
typedef enum {
|
||||
MergeSchedulerPeft = 0,
|
||||
} PimMergeSchedulerType;
|
||||
|
||||
typedef enum {
|
||||
PimMemoryReportNone = 0,
|
||||
PimMemoryReportSummary = 1,
|
||||
PimMemoryReportFull = 2,
|
||||
} PimMemoryReportLevel;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
|
||||
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
|
||||
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
|
||||
|
||||
extern llvm::cl::opt<bool> pimOnlyCodegen;
|
||||
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
|
||||
extern llvm::cl::opt<bool> useExperimentalConvImpl;
|
||||
extern llvm::cl::opt<bool> pimEmitJson;
|
||||
|
||||
extern llvm::cl::opt<size_t> crossbarSize;
|
||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
// This option, by default set to false, will ignore an error when resolving a
|
||||
// specific tiles of the operands of a concat. This specific case is when the
|
||||
|
||||
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
PassManager& pm,
|
||||
EmissionTargetType& emissionTarget,
|
||||
std::string outputNameNoExt) {
|
||||
verifyExplicitPimCoreCount();
|
||||
|
||||
if (pimOnlyCodegen) {
|
||||
// Skip all the lowering passes and directly generate code for PIM.
|
||||
@@ -29,31 +30,28 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
if (pimEmissionTarget >= EmitSpatial) {
|
||||
pm.addPass(createONNXToSpatialPass());
|
||||
pm.addPass(createMergeComputeNodesPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPim) {
|
||||
pm.addPass(createSpatialToPimPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||
pm.addPass(createPimBufferizationPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim bufferized"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimCodegen) {
|
||||
pm.addPass(createPimHostConstantFoldingPass());
|
||||
pm.addPass(createMessagePass("Pim host constants folded"));
|
||||
pm.addPass(createPimMaterializeHostConstantsPass());
|
||||
if (!pimDisableMemoryCoalescing)
|
||||
pm.addPass(createPimMemoryCoalescingPass());
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimJsonPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim json code emitted"));
|
||||
pm.addPass(createEmitPimCodePass());
|
||||
pm.addPass(createMessagePass("Pim code emitted"));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,733 @@
|
||||
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace onnx_mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
static std::optional<unsigned> getLaneForMemoryValue(mlir::Value value, std::optional<unsigned> lane) {
|
||||
if (!lane)
|
||||
return std::nullopt;
|
||||
auto allocOp = value.getDefiningOp<memref::AllocOp>();
|
||||
if (!allocOp || !allocOp->getParentOfType<pim::PimCoreBatchOp>())
|
||||
return std::nullopt;
|
||||
return lane;
|
||||
}
|
||||
|
||||
static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigned> lane = std::nullopt) {
|
||||
return {value, getLaneForMemoryValue(value, lane)};
|
||||
}
|
||||
|
||||
struct MemoryTouchInterval {
|
||||
uint64_t start = 0;
|
||||
uint64_t end = 0;
|
||||
Operation* startOp = nullptr;
|
||||
Operation* endOp = nullptr;
|
||||
Operation* firstTouchOp = nullptr;
|
||||
Operation* lastTouchOp = nullptr;
|
||||
uint64_t firstTouchPosition = 0;
|
||||
uint64_t lastTouchPosition = 0;
|
||||
bool hasRuntimeUse = false;
|
||||
bool startUsedAllocFallback = false;
|
||||
bool endUsedFallback = false;
|
||||
bool escapesLoop = false;
|
||||
std::string fallbackReason;
|
||||
llvm::SmallVector<std::string, 8> aliasesFollowed;
|
||||
};
|
||||
|
||||
struct OperationOrdering {
|
||||
llvm::DenseMap<Operation*, uint64_t> position;
|
||||
llvm::DenseMap<Operation*, uint64_t> subtreeEnd;
|
||||
uint64_t nextPosition = 0;
|
||||
};
|
||||
|
||||
static std::string printValueToString(mlir::Value value) {
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
value.print(os);
|
||||
os.flush();
|
||||
return text;
|
||||
}
|
||||
|
||||
static std::string printOperationToString(Operation* op) {
|
||||
if (!op)
|
||||
return "<none>";
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
op->print(os);
|
||||
os.flush();
|
||||
return text;
|
||||
}
|
||||
|
||||
static std::string printLocationToString(Location loc) {
|
||||
std::string text;
|
||||
llvm::raw_string_ostream os(text);
|
||||
loc.print(os);
|
||||
os.flush();
|
||||
return text;
|
||||
}
|
||||
|
||||
static std::string collapseWhitespace(StringRef text) {
|
||||
std::string out;
|
||||
out.reserve(text.size());
|
||||
bool lastWasSpace = false;
|
||||
for (char c : text) {
|
||||
bool isSpace = c == ' ' || c == '\n' || c == '\t' || c == '\r';
|
||||
if (isSpace) {
|
||||
if (!lastWasSpace && !out.empty())
|
||||
out.push_back(' ');
|
||||
lastWasSpace = true;
|
||||
continue;
|
||||
}
|
||||
out.push_back(c);
|
||||
lastWasSpace = false;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
static std::string abbreviate(StringRef text, size_t maxLen) {
|
||||
if (text.size() <= maxLen)
|
||||
return text.str();
|
||||
return (text.take_front(maxLen - 3) + "...").str();
|
||||
}
|
||||
|
||||
static std::string summarizeValue(mlir::Value value, size_t maxLen = 72) {
|
||||
return abbreviate(collapseWhitespace(printValueToString(value)), maxLen);
|
||||
}
|
||||
|
||||
static std::string summarizeOperation(Operation* op, size_t maxLen = 96) {
|
||||
if (!op)
|
||||
return "<none>";
|
||||
std::string prefix = op->getName().getStringRef().str();
|
||||
std::string full = collapseWhitespace(printOperationToString(op));
|
||||
if (full == prefix)
|
||||
return prefix;
|
||||
return abbreviate(prefix + " :: " + full, maxLen);
|
||||
}
|
||||
|
||||
static std::string summarizeLocation(Location loc, size_t maxLen = 88) {
|
||||
return abbreviate(collapseWhitespace(printLocationToString(loc)), maxLen);
|
||||
}
|
||||
|
||||
static void assignOperationOrdering(Operation* op, OperationOrdering& ordering) {
|
||||
uint64_t position = ordering.nextPosition++;
|
||||
ordering.position[op] = position;
|
||||
uint64_t end = position;
|
||||
for (Region& region : op->getRegions())
|
||||
for (Block& block : region)
|
||||
for (Operation& nestedOp : block) {
|
||||
assignOperationOrdering(&nestedOp, ordering);
|
||||
end = std::max(end, ordering.subtreeEnd.lookup(&nestedOp));
|
||||
}
|
||||
ordering.subtreeEnd[op] = end;
|
||||
}
|
||||
|
||||
static OperationOrdering buildOperationOrdering(Operation* coreLikeOp) {
|
||||
OperationOrdering ordering;
|
||||
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
|
||||
return ordering;
|
||||
|
||||
for (Operation& op : coreLikeOp->getRegion(0).front())
|
||||
assignOperationOrdering(&op, ordering);
|
||||
return ordering;
|
||||
}
|
||||
|
||||
static bool isSupportedAliasOp(Operation* op) {
|
||||
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
|
||||
}
|
||||
|
||||
static bool isRuntimeMemoryTouchOp(Operation* op) {
|
||||
return isa<pim::PimMemCopyHostToDevOp,
|
||||
pim::PimMemCopyDevToHostOp,
|
||||
pim::PimMemCopyOp,
|
||||
pim::PimReceiveOp,
|
||||
pim::PimSendOp,
|
||||
pim::PimConcatOp,
|
||||
pim::PimVMMOp,
|
||||
pim::PimTransposeOp,
|
||||
pim::PimVVAddOp,
|
||||
pim::PimVVSubOp,
|
||||
pim::PimVVMulOp,
|
||||
pim::PimVVMaxOp,
|
||||
pim::PimVVDMulOp,
|
||||
pim::PimVAvgOp,
|
||||
pim::PimVReluOp,
|
||||
pim::PimVTanhOp,
|
||||
pim::PimVSigmOp,
|
||||
pim::PimVSoftmaxOp>(op);
|
||||
}
|
||||
|
||||
static bool isIgnoredLivenessUser(Operation* op) {
|
||||
return isSupportedAliasOp(op) || isa<scf::ForOp, scf::YieldOp, memref::DeallocOp>(op) || isCoreStaticAddressOp(op);
|
||||
}
|
||||
|
||||
static bool isWithin(mlir::Value value, Region* region) {
|
||||
if (!region)
|
||||
return false;
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent() == region;
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion() == region || region->isAncestor(definingOp->getParentRegion());
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isNestedAllocation(Operation* coreLikeOp, memref::AllocOp allocOp) {
|
||||
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
|
||||
return false;
|
||||
return allocOp->getBlock() != &coreLikeOp->getRegion(0).front();
|
||||
}
|
||||
|
||||
static void addFallbackReason(std::string& reason, StringRef newReason) {
|
||||
if (newReason.empty())
|
||||
return;
|
||||
if (!reason.empty())
|
||||
reason += "; ";
|
||||
reason += newReason.str();
|
||||
}
|
||||
|
||||
static void appendAliasDescription(llvm::SmallVectorImpl<std::string>& aliases, mlir::Value value) {
|
||||
std::string text = printValueToString(value);
|
||||
if (!llvm::is_contained(aliases, text))
|
||||
aliases.push_back(std::move(text));
|
||||
}
|
||||
|
||||
struct OrderedTouchRange {
|
||||
uint64_t start = 0;
|
||||
uint64_t end = 0;
|
||||
Operation* startOp = nullptr;
|
||||
Operation* endOp = nullptr;
|
||||
bool escapedLoop = false;
|
||||
};
|
||||
|
||||
static OrderedTouchRange
|
||||
getEffectiveTouchRange(mlir::Value definingValue, Operation* user, const OperationOrdering& ordering) {
|
||||
OrderedTouchRange range {ordering.position.lookup(user), ordering.position.lookup(user), user, user, false};
|
||||
for (Operation* current = user; current; current = current->getParentOp()) {
|
||||
auto forOp = dyn_cast<scf::ForOp>(current);
|
||||
if (!forOp || isWithin(definingValue, &forOp.getRegion()))
|
||||
continue;
|
||||
range.start = std::min(range.start, ordering.position.lookup(forOp));
|
||||
range.end = std::max(range.end, ordering.subtreeEnd.lookup(forOp));
|
||||
range.startOp = forOp;
|
||||
range.endOp = forOp;
|
||||
range.escapedLoop = true;
|
||||
}
|
||||
return range;
|
||||
}
|
||||
|
||||
static MemoryTouchInterval
|
||||
computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering& ordering, uint64_t fallbackEnd) {
|
||||
MemoryTouchInterval interval;
|
||||
interval.start = ordering.position.lookup(allocOp);
|
||||
interval.end = interval.start;
|
||||
interval.startOp = allocOp;
|
||||
interval.endOp = allocOp;
|
||||
|
||||
SmallPtrSet<mlir::Value, 16> visitedValues;
|
||||
SmallPtrSet<Operation*, 32> visitedUsers;
|
||||
SmallVector<mlir::Value> pendingValues;
|
||||
pendingValues.push_back(allocOp.getResult());
|
||||
auto parentLoop = allocOp->getParentOfType<scf::ForOp>();
|
||||
|
||||
while (!pendingValues.empty()) {
|
||||
mlir::Value value = pendingValues.pop_back_val();
|
||||
if (!visitedValues.insert(value).second)
|
||||
continue;
|
||||
|
||||
for (Operation* user : value.getUsers()) {
|
||||
if (!visitedUsers.insert(user).second)
|
||||
continue;
|
||||
|
||||
if (isSupportedAliasOp(user)) {
|
||||
for (mlir::Value result : user->getResults()) {
|
||||
pendingValues.push_back(result);
|
||||
appendAliasDescription(interval.aliasesFollowed, result);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
|
||||
for (OpResult result : user->getResults()) {
|
||||
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
|
||||
if (!tiedOperand || tiedOperand->get() != value)
|
||||
continue;
|
||||
pendingValues.push_back(result);
|
||||
appendAliasDescription(interval.aliasesFollowed, result);
|
||||
}
|
||||
}
|
||||
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
|
||||
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
|
||||
if (initArg != value)
|
||||
continue;
|
||||
pendingValues.push_back(forOp.getRegionIterArgs()[index]);
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
appendAliasDescription(interval.aliasesFollowed, forOp.getRegionIterArgs()[index]);
|
||||
appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index));
|
||||
if (parentLoop && forOp != parentLoop)
|
||||
interval.escapesLoop = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
|
||||
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
|
||||
if (!forOp) {
|
||||
addFallbackReason(interval.fallbackReason, "yield without scf.for parent");
|
||||
}
|
||||
else {
|
||||
for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) {
|
||||
if (operand != value)
|
||||
continue;
|
||||
pendingValues.push_back(forOp.getResult(index));
|
||||
appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index));
|
||||
if (parentLoop && forOp == parentLoop)
|
||||
interval.escapesLoop = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isRuntimeMemoryTouchOp(user)) {
|
||||
uint64_t touchPosition = ordering.position.lookup(user);
|
||||
if (!interval.hasRuntimeUse || touchPosition < interval.firstTouchPosition) {
|
||||
interval.firstTouchPosition = touchPosition;
|
||||
interval.firstTouchOp = user;
|
||||
}
|
||||
if (!interval.hasRuntimeUse || touchPosition > interval.lastTouchPosition) {
|
||||
interval.lastTouchPosition = touchPosition;
|
||||
interval.lastTouchOp = user;
|
||||
}
|
||||
|
||||
OrderedTouchRange range = getEffectiveTouchRange(allocOp.getResult(), user, ordering);
|
||||
interval.escapesLoop |= range.escapedLoop;
|
||||
if (!interval.hasRuntimeUse) {
|
||||
interval.start = range.start;
|
||||
interval.end = range.end;
|
||||
interval.startOp = range.startOp;
|
||||
interval.endOp = range.endOp;
|
||||
interval.hasRuntimeUse = true;
|
||||
}
|
||||
else {
|
||||
if (range.start < interval.start) {
|
||||
interval.start = range.start;
|
||||
interval.startOp = range.startOp;
|
||||
}
|
||||
if (range.end > interval.end) {
|
||||
interval.end = range.end;
|
||||
interval.endOp = range.endOp;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isIgnoredLivenessUser(user))
|
||||
continue;
|
||||
|
||||
addFallbackReason(interval.fallbackReason, "unhandled user op");
|
||||
interval.endUsedFallback = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!interval.hasRuntimeUse) {
|
||||
interval.startUsedAllocFallback = true;
|
||||
interval.endUsedFallback = true;
|
||||
interval.start = ordering.position.lookup(allocOp);
|
||||
interval.end = fallbackEnd;
|
||||
interval.startOp = allocOp;
|
||||
interval.endOp = allocOp->getParentOp();
|
||||
interval.firstTouchPosition = interval.start;
|
||||
interval.lastTouchPosition = interval.end;
|
||||
addFallbackReason(interval.fallbackReason, "no runtime memory touch");
|
||||
return interval;
|
||||
}
|
||||
|
||||
if (interval.endUsedFallback) {
|
||||
interval.end = std::max(interval.end, fallbackEnd);
|
||||
interval.endOp = allocOp->getParentOp();
|
||||
}
|
||||
|
||||
return interval;
|
||||
}
|
||||
|
||||
static FailureOr<size_t> getAllocSizeBytes(memref::AllocOp allocOp) {
|
||||
auto type = dyn_cast<ShapedType>(allocOp.getType());
|
||||
if (!type)
|
||||
return failure();
|
||||
auto checkedBytes = pim::getCheckedShapedTypeSizeInBytes(type, allocOp, "memory allocation byte size");
|
||||
if (failed(checkedBytes))
|
||||
return failure();
|
||||
return pim::checkedSize(*checkedBytes, allocOp, "memory allocation byte size");
|
||||
}
|
||||
|
||||
static bool intervalsOverlap(const LocalAllocInterval& lhs, const LocalAllocInterval& rhs) {
|
||||
return !(lhs.end < rhs.start || rhs.end < lhs.start);
|
||||
}
|
||||
|
||||
static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot& slot, ArrayRef<LocalAllocInterval> intervals) {
|
||||
uint64_t slotLogicalBytes = 0;
|
||||
for (size_t intervalIndex : slot.intervalIndices)
|
||||
slotLogicalBytes += intervals[intervalIndex].size;
|
||||
return slotLogicalBytes;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SmallVector<LocalAllocInterval, 0> onnx_mlir::buildLocalAllocIntervals(Operation* coreLikeOp,
|
||||
std::optional<unsigned> lane) {
|
||||
SmallVector<LocalAllocInterval, 0> intervals;
|
||||
OperationOrdering ordering = buildOperationOrdering(coreLikeOp);
|
||||
if (ordering.position.empty())
|
||||
return intervals;
|
||||
|
||||
uint64_t fallbackEnd = ordering.nextPosition == 0 ? 0 : ordering.nextPosition - 1;
|
||||
size_t nextIntervalId = 0;
|
||||
coreLikeOp->walk([&](memref::AllocOp allocOp) {
|
||||
auto checkedSize = getAllocSizeBytes(allocOp);
|
||||
if (failed(checkedSize)) {
|
||||
llvm::errs() << "Failed to compute local allocation size for value: ";
|
||||
allocOp.getResult().print(llvm::errs());
|
||||
llvm::errs() << "\n";
|
||||
llvm_unreachable("Failed to compute local allocation size");
|
||||
}
|
||||
|
||||
MemoryTouchInterval touchInterval = computeMemoryTouchInterval(allocOp, ordering, fallbackEnd);
|
||||
LocalAllocInterval interval;
|
||||
interval.id = nextIntervalId++;
|
||||
interval.alloc = allocOp;
|
||||
interval.key = getMemoryValueKey(allocOp.getResult(), lane);
|
||||
interval.start = touchInterval.start;
|
||||
interval.end = touchInterval.end;
|
||||
interval.size = *checkedSize;
|
||||
interval.startOp = touchInterval.startOp;
|
||||
interval.endOp = touchInterval.endOp;
|
||||
interval.firstTouchOp = touchInterval.firstTouchOp;
|
||||
interval.lastTouchOp = touchInterval.lastTouchOp;
|
||||
interval.firstTouchPosition = touchInterval.firstTouchPosition;
|
||||
interval.lastTouchPosition = touchInterval.lastTouchPosition;
|
||||
interval.startUsedAllocFallback = touchInterval.startUsedAllocFallback;
|
||||
interval.endUsedFallback = touchInterval.endUsedFallback;
|
||||
interval.hasRuntimeUse = touchInterval.hasRuntimeUse;
|
||||
interval.insideNestedRegion = isNestedAllocation(coreLikeOp, allocOp);
|
||||
interval.escapesLoop = touchInterval.escapesLoop;
|
||||
interval.fallbackReason = std::move(touchInterval.fallbackReason);
|
||||
interval.aliasesFollowed = std::move(touchInterval.aliasesFollowed);
|
||||
intervals.push_back(std::move(interval));
|
||||
});
|
||||
|
||||
return intervals;
|
||||
}
|
||||
|
||||
SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef<LocalAllocInterval> intervals) {
|
||||
SmallVector<PlannedPhysicalSlot, 0> slots;
|
||||
SmallVector<size_t> intervalOrder(intervals.size());
|
||||
std::iota(intervalOrder.begin(), intervalOrder.end(), 0);
|
||||
llvm::stable_sort(intervalOrder, [&](size_t lhsIndex, size_t rhsIndex) {
|
||||
const LocalAllocInterval& lhs = intervals[lhsIndex];
|
||||
const LocalAllocInterval& rhs = intervals[rhsIndex];
|
||||
if (lhs.size != rhs.size)
|
||||
return lhs.size > rhs.size;
|
||||
if (lhs.start != rhs.start)
|
||||
return lhs.start < rhs.start;
|
||||
if (lhs.end != rhs.end)
|
||||
return lhs.end < rhs.end;
|
||||
return lhs.id < rhs.id;
|
||||
});
|
||||
|
||||
for (size_t intervalIndex : intervalOrder) {
|
||||
LocalAllocInterval& interval = intervals[intervalIndex];
|
||||
PlannedPhysicalSlot* bestSlot = nullptr;
|
||||
auto bestKey = std::tuple<size_t, size_t, size_t, size_t>(std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<size_t>::max(),
|
||||
std::numeric_limits<size_t>::max());
|
||||
|
||||
for (size_t slotIndex = 0; slotIndex < slots.size(); ++slotIndex) {
|
||||
PlannedPhysicalSlot& slot = slots[slotIndex];
|
||||
bool compatible = true;
|
||||
for (size_t otherIndex : slot.intervalIndices) {
|
||||
if (intervalsOverlap(interval, intervals[otherIndex])) {
|
||||
compatible = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!compatible)
|
||||
continue;
|
||||
|
||||
size_t resultingSize = std::max(slot.requiredSize, interval.size);
|
||||
size_t growth = resultingSize - slot.requiredSize;
|
||||
auto candidateKey =
|
||||
std::tuple<size_t, size_t, size_t, size_t>(growth, resultingSize, slot.intervalIndices.size(), slot.id);
|
||||
if (candidateKey < bestKey) {
|
||||
bestKey = candidateKey;
|
||||
bestSlot = &slot;
|
||||
}
|
||||
}
|
||||
|
||||
if (!bestSlot) {
|
||||
slots.push_back({slots.size(), interval.size, interval.size, 0, {intervalIndex}});
|
||||
interval.slotPlanIndex = slots.size() - 1;
|
||||
interval.physicalSlotId = slots.back().id;
|
||||
interval.physicalSlotSize = slots.back().requiredSize;
|
||||
continue;
|
||||
}
|
||||
|
||||
bestSlot->requiredSize = std::max(bestSlot->requiredSize, interval.size);
|
||||
bestSlot->size = bestSlot->requiredSize;
|
||||
bestSlot->intervalIndices.push_back(intervalIndex);
|
||||
interval.slotPlanIndex = static_cast<size_t>(bestSlot - slots.data());
|
||||
interval.physicalSlotId = bestSlot->id;
|
||||
interval.physicalSlotSize = bestSlot->requiredSize;
|
||||
}
|
||||
|
||||
return slots;
|
||||
}
|
||||
|
||||
MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation* coreLikeOp,
|
||||
std::optional<unsigned> lane,
|
||||
ArrayRef<LocalAllocInterval> intervals,
|
||||
ArrayRef<PlannedPhysicalSlot> slots,
|
||||
size_t addressLimit,
|
||||
PimMemoryReportLevel reportLevel) {
|
||||
MemoryPlanArtifacts artifacts;
|
||||
|
||||
uint64_t totalLogicalBytes = 0;
|
||||
uint64_t totalPhysicalBytes = 0;
|
||||
uint64_t fallbackIntervals = 0;
|
||||
uint64_t noRuntimeTouchIntervals = 0;
|
||||
uint64_t reusedAllocations = 0;
|
||||
uint64_t nestedIntervals = 0;
|
||||
uint64_t loopEscapingIntervals = 0;
|
||||
size_t largestLogicalAllocation = 0;
|
||||
size_t largestPhysicalSlot = 0;
|
||||
size_t maximumAssignedAddress = 0;
|
||||
|
||||
for (const LocalAllocInterval& interval : intervals) {
|
||||
totalLogicalBytes += interval.size;
|
||||
largestLogicalAllocation = std::max(largestLogicalAllocation, interval.size);
|
||||
maximumAssignedAddress = std::max(maximumAssignedAddress, interval.assignedAddress + interval.physicalSlotSize);
|
||||
if (interval.startUsedAllocFallback || interval.endUsedFallback)
|
||||
++fallbackIntervals;
|
||||
if (!interval.hasRuntimeUse)
|
||||
++noRuntimeTouchIntervals;
|
||||
if (interval.insideNestedRegion)
|
||||
++nestedIntervals;
|
||||
if (interval.escapesLoop)
|
||||
++loopEscapingIntervals;
|
||||
}
|
||||
for (const PlannedPhysicalSlot& slot : slots) {
|
||||
totalPhysicalBytes += slot.size;
|
||||
largestPhysicalSlot = std::max(largestPhysicalSlot, slot.size);
|
||||
if (slot.intervalIndices.size() > 1)
|
||||
reusedAllocations += slot.intervalIndices.size() - 1;
|
||||
}
|
||||
|
||||
uint64_t savedBytes = totalLogicalBytes >= totalPhysicalBytes ? totalLogicalBytes - totalPhysicalBytes : 0;
|
||||
double savedPercent =
|
||||
totalLogicalBytes == 0 ? 0.0 : 100.0 * static_cast<double>(savedBytes) / static_cast<double>(totalLogicalBytes);
|
||||
|
||||
raw_string_ostream os(artifacts.textReport);
|
||||
os << "=== PIM Memory Liveness Report ===\n";
|
||||
os << "Op: " << coreLikeOp->getName() << "\n";
|
||||
if (lane)
|
||||
os << "Lane: " << *lane << "\n";
|
||||
os << "Summary:\n";
|
||||
os << " logical allocation bytes: " << formatReportMemory(totalLogicalBytes) << " (" << totalLogicalBytes << ")\n";
|
||||
os << " physical allocation bytes: " << formatReportMemory(totalPhysicalBytes) << " (" << totalPhysicalBytes
|
||||
<< ")\n";
|
||||
os << " saved bytes: " << formatReportMemory(savedBytes) << " (" << savedBytes << ")\n";
|
||||
os << " saved percent: " << format("%.2f%%", savedPercent) << "\n";
|
||||
os << " intervals: " << intervals.size() << "\n";
|
||||
os << " physical slots: " << slots.size() << "\n";
|
||||
os << " reused allocations: " << reusedAllocations << "\n";
|
||||
os << " fallback intervals: " << fallbackIntervals << "\n";
|
||||
os << " intervals with no runtime memory touch: " << noRuntimeTouchIntervals << "\n";
|
||||
os << " nested allocations: " << nestedIntervals << "\n";
|
||||
os << " loop-escaping allocations: " << loopEscapingIntervals << "\n";
|
||||
os << " largest logical allocation: " << largestLogicalAllocation << "\n";
|
||||
os << " largest physical slot: " << largestPhysicalSlot << "\n";
|
||||
os << " address limit: " << addressLimit << "\n";
|
||||
os << " peak physical memory: " << formatReportMemory(maximumAssignedAddress) << " (" << maximumAssignedAddress
|
||||
<< ")\n";
|
||||
os << " maximum assigned address: " << maximumAssignedAddress << "\n";
|
||||
|
||||
os << "\nHow To Read:\n";
|
||||
os << " `summary` only shows the strongest reuse cases and the worst offenders.\n";
|
||||
os << " Use `--pim-memory-report=full` when you need the complete slot-by-slot and interval-by-interval dump.\n";
|
||||
os << " Large single-use slots, fallback intervals, and nested single-use allocations are the best places\n";
|
||||
os << " to inspect if allocations should be moved, sunk, or made easier to coalesce earlier in the pipeline.\n";
|
||||
|
||||
SmallVector<const PlannedPhysicalSlot*> reusedSlots;
|
||||
SmallVector<const PlannedPhysicalSlot*> singleUseSlots;
|
||||
for (const PlannedPhysicalSlot& slot : slots)
|
||||
if (slot.intervalIndices.size() > 1)
|
||||
reusedSlots.push_back(&slot);
|
||||
else
|
||||
singleUseSlots.push_back(&slot);
|
||||
|
||||
llvm::stable_sort(reusedSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
|
||||
uint64_t lhsLogicalBytes = getSlotLogicalBytes(*lhs, intervals);
|
||||
uint64_t rhsLogicalBytes = getSlotLogicalBytes(*rhs, intervals);
|
||||
if (lhs->intervalIndices.size() != rhs->intervalIndices.size())
|
||||
return lhs->intervalIndices.size() > rhs->intervalIndices.size();
|
||||
if (lhsLogicalBytes != rhsLogicalBytes)
|
||||
return lhsLogicalBytes > rhsLogicalBytes;
|
||||
if (lhs->size != rhs->size)
|
||||
return lhs->size > rhs->size;
|
||||
return lhs->id < rhs->id;
|
||||
});
|
||||
llvm::stable_sort(singleUseSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
|
||||
if (lhs->size != rhs->size)
|
||||
return lhs->size > rhs->size;
|
||||
return lhs->id < rhs->id;
|
||||
});
|
||||
|
||||
constexpr size_t kSummaryReuseLimit = 6;
|
||||
constexpr size_t kSummaryOffenderLimit = 10;
|
||||
|
||||
os << "\nBest Reuse:\n";
|
||||
if (reusedSlots.empty()) {
|
||||
os << " no slots were shared by multiple intervals\n";
|
||||
}
|
||||
else {
|
||||
for (const PlannedPhysicalSlot* slot : ArrayRef(reusedSlots).take_front(kSummaryReuseLimit)) {
|
||||
uint64_t slotLogicalBytes = getSlotLogicalBytes(*slot, intervals);
|
||||
os << " slot #" << slot->id << " addr=" << slot->address << " size=" << formatReportMemory(slot->size)
|
||||
<< " intervals=" << slot->intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
|
||||
<< "\n";
|
||||
for (size_t intervalIndex : slot->intervalIndices) {
|
||||
const LocalAllocInterval& interval = intervals[intervalIndex];
|
||||
os << " #" << interval.id << " [" << interval.start << "," << interval.end << "]"
|
||||
<< " logical=" << formatReportMemory(interval.size)
|
||||
<< " first=" << summarizeOperation(interval.firstTouchOp, 40)
|
||||
<< " last=" << summarizeOperation(interval.lastTouchOp, 40) << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
os << "\nTop Offenders:\n";
|
||||
bool printedAttention = false;
|
||||
for (const PlannedPhysicalSlot* slot : ArrayRef(singleUseSlots).take_front(kSummaryOffenderLimit)) {
|
||||
const LocalAllocInterval& interval = intervals[slot->intervalIndices.front()];
|
||||
printedAttention = true;
|
||||
os << " slot #" << slot->id << " is single-use"
|
||||
<< " size=" << formatReportMemory(slot->size) << " interval=#" << interval.id
|
||||
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
|
||||
os << " first=" << summarizeOperation(interval.firstTouchOp, 40)
|
||||
<< " last=" << summarizeOperation(interval.lastTouchOp, 40)
|
||||
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
|
||||
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") << "\n";
|
||||
}
|
||||
size_t fallbackPrinted = 0;
|
||||
for (const LocalAllocInterval& interval : intervals) {
|
||||
if (!(interval.startUsedAllocFallback || interval.endUsedFallback) || fallbackPrinted >= kSummaryOffenderLimit)
|
||||
continue;
|
||||
printedAttention = true;
|
||||
++fallbackPrinted;
|
||||
os << " fallback interval #" << interval.id << " size=" << formatReportMemory(interval.size)
|
||||
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
|
||||
os << " reason: " << (interval.fallbackReason.empty() ? "<none>" : interval.fallbackReason) << "\n";
|
||||
}
|
||||
size_t nestedPrinted = 0;
|
||||
for (const LocalAllocInterval& interval : intervals) {
|
||||
if (nestedPrinted >= kSummaryOffenderLimit)
|
||||
break;
|
||||
if (!(interval.insideNestedRegion && slots[interval.slotPlanIndex].intervalIndices.size() == 1))
|
||||
continue;
|
||||
printedAttention = true;
|
||||
++nestedPrinted;
|
||||
os << " nested single-use interval #" << interval.id << " slot #" << interval.physicalSlotId
|
||||
<< " size=" << formatReportMemory(interval.size) << " value=" << summarizeValue(interval.key.value, 56)
|
||||
<< "\n";
|
||||
os << " hint: move or sink this alloc inside the nested region if the IR allows it.\n";
|
||||
}
|
||||
if (!printedAttention)
|
||||
os << " no obvious blockers detected in this core\n";
|
||||
|
||||
if (reportLevel == PimMemoryReportFull) {
|
||||
os << "\nSlot Reuse:\n";
|
||||
for (const PlannedPhysicalSlot& slot : slots) {
|
||||
uint64_t slotLogicalBytes = getSlotLogicalBytes(slot, intervals);
|
||||
os << " slot #" << slot.id << " addr=" << slot.address << " size=" << formatReportMemory(slot.size) << " ("
|
||||
<< slot.size << ")"
|
||||
<< " intervals=" << slot.intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
|
||||
<< "\n";
|
||||
for (size_t intervalIndex : slot.intervalIndices) {
|
||||
const LocalAllocInterval& interval = intervals[intervalIndex];
|
||||
mlir::Value allocValue = interval.key.value;
|
||||
os << " [" << interval.start << "," << interval.end << "]"
|
||||
<< " #" << interval.id << " logical=" << formatReportMemory(interval.size)
|
||||
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
|
||||
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
|
||||
<< " first=" << summarizeOperation(interval.firstTouchOp, 48)
|
||||
<< " last=" << summarizeOperation(interval.lastTouchOp, 48) << "\n";
|
||||
os << " value=" << summarizeValue(allocValue) << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (reportLevel == PimMemoryReportFull) {
|
||||
os << "\nInterval Details:\n";
|
||||
for (const LocalAllocInterval& interval : intervals) {
|
||||
const PlannedPhysicalSlot& slot = slots[interval.slotPlanIndex];
|
||||
mlir::Value allocValue = interval.key.value;
|
||||
Operation* definingOp = allocValue.getDefiningOp();
|
||||
os << " #" << interval.id << " slot=" << slot.id << " live=[" << interval.start << "," << interval.end << "]"
|
||||
<< " logical=" << formatReportMemory(interval.size)
|
||||
<< " slot_size=" << formatReportMemory(interval.physicalSlotSize) << " addr=" << interval.assignedAddress
|
||||
<< "\n";
|
||||
os << " value=" << summarizeValue(allocValue, 88) << "\n";
|
||||
os << " type=" << allocValue.getType() << "\n";
|
||||
os << " loc="
|
||||
<< summarizeLocation(definingOp ? definingOp->getLoc() : UnknownLoc::get(coreLikeOp->getContext())) << "\n";
|
||||
os << " nested=" << (interval.insideNestedRegion ? "yes" : "no")
|
||||
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
|
||||
<< " start_fallback=" << (interval.startUsedAllocFallback ? "yes" : "no")
|
||||
<< " end_fallback=" << (interval.endUsedFallback ? "yes" : "no") << "\n";
|
||||
os << " first_use=" << summarizeOperation(interval.firstTouchOp) << " @" << interval.firstTouchPosition
|
||||
<< "\n";
|
||||
os << " last_use=" << summarizeOperation(interval.lastTouchOp) << " @" << interval.lastTouchPosition << "\n";
|
||||
os << " slot_peers=";
|
||||
bool first = true;
|
||||
for (size_t otherIndex : slot.intervalIndices) {
|
||||
if (intervals[otherIndex].id == interval.id)
|
||||
continue;
|
||||
if (!first)
|
||||
os << ", ";
|
||||
os << "#" << intervals[otherIndex].id;
|
||||
first = false;
|
||||
}
|
||||
if (first)
|
||||
os << "<none>";
|
||||
os << "\n";
|
||||
if (!interval.fallbackReason.empty())
|
||||
os << " fallback_reason=" << interval.fallbackReason << "\n";
|
||||
if (!interval.aliasesFollowed.empty()) {
|
||||
os << " aliases_followed=" << interval.aliasesFollowed.size() << "\n";
|
||||
for (const std::string& alias : interval.aliasesFollowed)
|
||||
os << " - " << abbreviate(collapseWhitespace(alias), 108) << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
os.flush();
|
||||
|
||||
return artifacts;
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct LocalAllocInterval {
|
||||
size_t id = 0;
|
||||
mlir::memref::AllocOp alloc;
|
||||
MemoryValueKey key;
|
||||
uint64_t start = 0;
|
||||
uint64_t end = 0;
|
||||
size_t size = 0;
|
||||
mlir::Operation* startOp = nullptr;
|
||||
mlir::Operation* endOp = nullptr;
|
||||
mlir::Operation* firstTouchOp = nullptr;
|
||||
mlir::Operation* lastTouchOp = nullptr;
|
||||
uint64_t firstTouchPosition = 0;
|
||||
uint64_t lastTouchPosition = 0;
|
||||
bool startUsedAllocFallback = false;
|
||||
bool endUsedFallback = false;
|
||||
bool hasRuntimeUse = false;
|
||||
bool insideNestedRegion = false;
|
||||
bool escapesLoop = false;
|
||||
std::string fallbackReason;
|
||||
llvm::SmallVector<std::string, 8> aliasesFollowed;
|
||||
size_t slotPlanIndex = std::numeric_limits<size_t>::max();
|
||||
size_t physicalSlotId = std::numeric_limits<size_t>::max();
|
||||
size_t assignedAddress = 0;
|
||||
size_t physicalSlotSize = 0;
|
||||
};
|
||||
|
||||
struct PlannedPhysicalSlot {
|
||||
size_t id = std::numeric_limits<size_t>::max();
|
||||
size_t requiredSize = 0;
|
||||
size_t size = 0;
|
||||
size_t address = 0;
|
||||
llvm::SmallVector<size_t, 8> intervalIndices;
|
||||
};
|
||||
|
||||
llvm::SmallVector<LocalAllocInterval, 0> buildLocalAllocIntervals(mlir::Operation* coreLikeOp,
|
||||
std::optional<unsigned> lane);
|
||||
|
||||
llvm::SmallVector<PlannedPhysicalSlot, 0> planPhysicalSlots(llvm::MutableArrayRef<LocalAllocInterval> intervals);
|
||||
|
||||
MemoryPlanArtifacts buildMemoryPlanArtifacts(mlir::Operation* coreLikeOp,
|
||||
std::optional<unsigned> lane,
|
||||
llvm::ArrayRef<LocalAllocInterval> intervals,
|
||||
llvm::ArrayRef<PlannedPhysicalSlot> slots,
|
||||
size_t addressLimit,
|
||||
PimMemoryReportLevel reportLevel);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,93 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "Common/Support/CheckedArithmetic.hpp"
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {} // namespace
|
||||
|
||||
WeightEmissionResult createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||
assert(!error && "Error creating weights directory");
|
||||
size_t indexFileName = 0;
|
||||
|
||||
int64_t xbarSize = crossbarSize.getValue();
|
||||
WeightEmissionResult result;
|
||||
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
|
||||
|
||||
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
|
||||
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
|
||||
it != materializedWeights.end())
|
||||
return it->second;
|
||||
|
||||
auto globalOp = weightView.globalOp;
|
||||
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
assert(denseAttr && "Weight global must have dense initial value");
|
||||
|
||||
ArrayRef<int64_t> shape = weightView.shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
std::error_code errorCode;
|
||||
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
|
||||
if (errorCode) {
|
||||
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
|
||||
assert(errorCode);
|
||||
}
|
||||
|
||||
uint64_t zero = 0;
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
}
|
||||
else {
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
materializedWeights.push_back({weightView, newFileName});
|
||||
uint64_t weightBytes = pim::checkedMulOrCrash(
|
||||
pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
|
||||
elementByteWidth,
|
||||
"weight byte size");
|
||||
result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
|
||||
return newFileName;
|
||||
};
|
||||
|
||||
for (const WeightFileRequest& request : requests) {
|
||||
auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
|
||||
coreFiles.reserve(request.weights.size());
|
||||
for (const ResolvedWeightView& weight : request.weights)
|
||||
coreFiles.push_back(materializeWeight(weight));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,29 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct WeightFileRequest {
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<ResolvedWeightView, 8> weights;
|
||||
};
|
||||
|
||||
struct WeightEmissionResult {
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
|
||||
uint64_t totalWeightBytes = 0;
|
||||
};
|
||||
|
||||
WeightEmissionResult createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests,
|
||||
llvm::StringRef outputDirPath);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,6 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
Patterns.cpp
|
||||
CompileTime.cpp
|
||||
ONNXToSpatialVerifier.cpp
|
||||
Patterns/Pre.cpp
|
||||
Patterns/Post.cpp
|
||||
Patterns/GeneratedConversion.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
@@ -16,9 +22,15 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Tensor/Gather.cpp
|
||||
Patterns/Tensor/Resize.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
Patterns/Tensor/Slice.cpp
|
||||
Patterns/Tensor/Split.cpp
|
||||
Patterns/Tensor/Transpose.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
Common.cpp
|
||||
Common/AttributeUtils.cpp
|
||||
Common/ComputeRegionBuilder.cpp
|
||||
Common/IndexingUtils.cpp
|
||||
Common/ShapeTilingUtils.cpp
|
||||
Common/WeightMaterialization.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
@@ -26,6 +38,7 @@ add_pim_library(OMONNXToSpatial
|
||||
ONNXToSpatialIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "Common.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
assert("Invalid axis" && axis < shape.size());
|
||||
|
||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (const auto size : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(size));
|
||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||
|
||||
long length = shape[axis];
|
||||
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(numSlices);
|
||||
|
||||
for (int64_t i = 0; i < numSlices; i++) {
|
||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||
if (i == numSlices - 1 && lastSliceSize != 0)
|
||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
||||
return slices;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||
assert("Not a vector" && isVectorShape(shape));
|
||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
|
||||
}
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>>
|
||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||
size_t coreId = sliceId / crossbarCountInCore;
|
||||
slicesPerCore[coreId].push_back(slices[sliceId]);
|
||||
}
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
|
||||
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
|
||||
|
||||
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
|
||||
size_t numHSlices = hSlices.size();
|
||||
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
|
||||
Value hSlice = hSlices[hSliceId];
|
||||
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
|
||||
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
|
||||
size_t coreId = vSliceId / crossbarCountInCore;
|
||||
Value vSlice = vSlices[vSliceId];
|
||||
tiles[hSliceId][coreId].push_back(vSlice);
|
||||
}
|
||||
}
|
||||
return tiles;
|
||||
}
|
||||
|
||||
tensor::SplatOp
|
||||
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
|
||||
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
}
|
||||
|
||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
if (tensors.size() == 1)
|
||||
return tensors[0];
|
||||
|
||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||
SmallVector<Value> tensors2;
|
||||
tensors2.reserve(tensors.size() / 2);
|
||||
|
||||
auto* currTensors = &tensors1;
|
||||
auto* nextTensors = &tensors2;
|
||||
while (currTensors->size() > 1) {
|
||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||
Value a = (*currTensors)[i];
|
||||
Value b = (*currTensors)[i + 1];
|
||||
rewriter.setInsertionPointAfterValue(b);
|
||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||
nextTensors->push_back(addedValue);
|
||||
}
|
||||
if (currTensors->size() % 2 == 1)
|
||||
nextTensors->push_back(currTensors->back());
|
||||
std::swap(currTensors, nextTensors);
|
||||
nextTensors->clear();
|
||||
}
|
||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||
return (*currTensors)[0];
|
||||
}
|
||||
|
||||
}; // namespace onnx_mlir
|
||||
@@ -1,279 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(1);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageN(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return 1 + (ac - 1) / bc;
|
||||
}
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[1] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||
assert(isVectorShape(shape));
|
||||
return shape[0] != 1 ? shape[0] : shape[1];
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
inline bool isWeightLikeComputeOperand(mlir::Value value) {
|
||||
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
|
||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (hasWeightAlways(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
template <size_t>
|
||||
using ValueArg = mlir::Value;
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
struct InvokeWithBlockArgsResult;
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||
};
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||
|
||||
template <typename Fn>
|
||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
||||
}
|
||||
else {
|
||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto computeOp = spatial::SpatWeightedCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatWeightedCompute>(computeOp);
|
||||
}
|
||||
else {
|
||||
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||
tileMatrix(mlir::Value& matrixToTile,
|
||||
int64_t hSliceSize,
|
||||
int64_t vSliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
|
||||
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
}; // namespace onnx_mlir
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
|
||||
|
||||
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
|
||||
return attr ? getI64Attr(*attr, index) : defaultValue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
|
||||
llvm::SmallVector<int64_t> values;
|
||||
values.reserve(attr.size());
|
||||
for (Attribute value : attr)
|
||||
values.push_back(cast<IntegerAttr>(value).getInt());
|
||||
return values;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
|
||||
|
||||
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
|
||||
|
||||
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
@@ -0,0 +1,39 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
|
||||
if (tensors.size() == 1)
|
||||
return tensors[0];
|
||||
|
||||
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
|
||||
SmallVector<Value> tensors2;
|
||||
tensors2.reserve(tensors.size() / 2);
|
||||
|
||||
auto* currTensors = &tensors1;
|
||||
auto* nextTensors = &tensors2;
|
||||
while (currTensors->size() > 1) {
|
||||
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
|
||||
Value a = (*currTensors)[i];
|
||||
Value b = (*currTensors)[i + 1];
|
||||
rewriter.setInsertionPointAfterValue(b);
|
||||
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
|
||||
nextTensors->push_back(addedValue);
|
||||
}
|
||||
if (currTensors->size() % 2 == 1)
|
||||
nextTensors->push_back(currTensors->back());
|
||||
std::swap(currTensors, nextTensors);
|
||||
nextTensors->clear();
|
||||
}
|
||||
assert(currTensors->size() == 1 && "Expected a single input at this point.");
|
||||
return (*currTensors)[0];
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,267 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
template <size_t>
|
||||
using ValueArg = mlir::Value;
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
struct InvokeWithBlockArgsResult;
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
|
||||
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
|
||||
};
|
||||
|
||||
template <typename Fn, typename Seq>
|
||||
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
|
||||
|
||||
template <typename Fn>
|
||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
struct SpatComputeBatchBodyArgs {
|
||||
mlir::Value lane;
|
||||
mlir::ValueRange weights;
|
||||
mlir::ValueRange inputs;
|
||||
mlir::ValueRange outputs;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename RewriterT>
|
||||
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
|
||||
assert(!inputs.empty() && "spat.concat requires at least one input");
|
||||
if (inputs.size() == 1)
|
||||
return inputs.front();
|
||||
|
||||
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
|
||||
auto outputShape = llvm::to_vector(firstType.getShape());
|
||||
int64_t concatDimSize = 0;
|
||||
bool concatDimDynamic = false;
|
||||
|
||||
for (mlir::Value input : inputs) {
|
||||
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
|
||||
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
|
||||
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
|
||||
concatDimDynamic = true;
|
||||
else
|
||||
concatDimSize += inputType.getDimSize(axis);
|
||||
}
|
||||
|
||||
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
|
||||
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
|
||||
}
|
||||
|
||||
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
|
||||
/// the body callback reports failure.
|
||||
template <size_t NumInputs, typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `spat.compute` whose body consumes the block arguments as a single
|
||||
/// `ValueRange`, which is convenient for variadic reductions/concats.
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatCompute(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
computeOp.getBody().push_back(block);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
int64_t laneCount,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
|
||||
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
|
||||
if (mlir::failed(laneCountAttr))
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
|
||||
|
||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||
for (mlir::Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
for (mlir::Value input : inputs) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(input.getLoc());
|
||||
}
|
||||
for (mlir::Type resultType : resultTypes) {
|
||||
blockArgTypes.push_back(resultType);
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* block =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
detail::SpatComputeBatchBodyArgs args {
|
||||
block->getArgument(0),
|
||||
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
||||
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
||||
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
|
||||
|
||||
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(args);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
rewriter.eraseOp(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
}
|
||||
}
|
||||
|
||||
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> strides) {
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename BodyFn>
|
||||
mlir::Value materializeOrComputeUnary(mlir::Value input,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
BodyFn&& build) {
|
||||
auto&& buildFn = build;
|
||||
if (isCompileTimeComputable(input))
|
||||
return buildFn(input);
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||
mlir::Value result = buildFn(computeInput);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,45 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
|
||||
int64_t normalizedAxis = normalizeAxis(axis, rank);
|
||||
if (normalizedAxis < 0 || normalizedAxis >= rank)
|
||||
return failure();
|
||||
return normalizedAxis;
|
||||
}
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
||||
|
||||
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes;
|
||||
if (!axesAttr) {
|
||||
normalizedAxes.reserve(rank);
|
||||
for (int64_t axis = 0; axis < rank; ++axis)
|
||||
normalizedAxes.push_back(axis);
|
||||
}
|
||||
else {
|
||||
normalizedAxes.reserve(axesAttr->size());
|
||||
for (Attribute attr : *axesAttr)
|
||||
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
|
||||
llvm::sort(normalizedAxes);
|
||||
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
||||
}
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
||||
for (int64_t axis : normalizedAxes)
|
||||
if (axis < 0 || axis >= rank)
|
||||
return failure();
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t normalizeAxis(int64_t axis, int64_t rank);
|
||||
|
||||
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,179 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||
SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
SmallVector<bool> seen(rank, false);
|
||||
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
assert("Invalid axis" && axis < shape.size());
|
||||
|
||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
|
||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||
|
||||
long length = shape[axis];
|
||||
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(numSlices);
|
||||
|
||||
for (int64_t i = 0; i < numSlices; i++) {
|
||||
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
|
||||
int64_t currentSliceSize = sliceSize;
|
||||
if (i == numSlices - 1 && lastSliceSize != 0) {
|
||||
currentSliceSize = lastSliceSize;
|
||||
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
|
||||
}
|
||||
|
||||
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
|
||||
sliceShape[axis] = currentSliceSize;
|
||||
auto sliceType =
|
||||
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||
|
||||
Value slice;
|
||||
if (isCompileTimeComputable(tensorToSlice)) {
|
||||
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
}
|
||||
else {
|
||||
auto sliceCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
|
||||
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
|
||||
});
|
||||
slice = sliceCompute.getResult(0);
|
||||
}
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
||||
return slices;
|
||||
}
|
||||
|
||||
SmallVector<Value>
|
||||
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(vectorToSlice);
|
||||
assert("Not a vector" && isVectorShape(shape));
|
||||
size_t axis = shape[0] != 1 ? 0 : 1;
|
||||
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
|
||||
}
|
||||
|
||||
DenseMap<CoreId, SmallVector<Value>>
|
||||
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
|
||||
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
|
||||
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
|
||||
size_t coreId = sliceId / crossbarCountInCore;
|
||||
slicesPerCore[coreId].push_back(slices[sliceId]);
|
||||
}
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
return tensor::InsertSliceOp::create(rewriter,
|
||||
loc,
|
||||
source,
|
||||
dest,
|
||||
offsets,
|
||||
getStaticSizes(rewriter, sourceType.getShape()),
|
||||
getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,114 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr C ceilIntegerDivide(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return 1 + (ac - 1) / bc;
|
||||
}
|
||||
|
||||
template <class A, class B, class C = std::common_type_t<A, B>>
|
||||
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
|
||||
static_assert(std::is_integral_v<A>, "A must be an integer type");
|
||||
static_assert(std::is_integral_v<B>, "B must be an integer type");
|
||||
C ac = static_cast<C>(a);
|
||||
C bc = static_cast<C>(b);
|
||||
return {ceilIntegerDivide(ac, bc), ac % bc};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isMatrixShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
|
||||
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
|
||||
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
|
||||
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
/// at most `sliceSize` elements.
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
size_t axis,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
int64_t sliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
/// Partitions one logical vector into per-core crossbar-sized slices using the
|
||||
/// current PIM target geometry.
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,115 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isWeightLikeComputeOperand(Value value) {
|
||||
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!rankedType || !isMatrixShape(rankedType.getShape()))
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
|
||||
while (auto* definingOp = value.getDefiningOp()) {
|
||||
if (!visited.insert(definingOp).second)
|
||||
return false;
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
|
||||
return true;
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
value = extractSliceOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
value = expandShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||
value = transposeOp.getInput();
|
||||
continue;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
|
||||
if (auto mapped = mapper.lookupOrNull(value))
|
||||
return cast<Value>(mapped);
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
|
||||
if (!tensorType || !tensorType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(tensorType.getRank());
|
||||
for (int64_t dim : tensorType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
|
||||
auto referencedValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
|
||||
mapper.map(value, referencedValue.getResult());
|
||||
return referencedValue.getResult();
|
||||
}
|
||||
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
|
||||
return failure();
|
||||
|
||||
IRMapping localMapper;
|
||||
for (Value operand : definingOp->getOperands()) {
|
||||
if (auto mapped = mapper.lookupOrNull(operand)) {
|
||||
localMapper.map(operand, cast<Value>(mapped));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (isWeightLikeComputeOperand(operand)) {
|
||||
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
|
||||
if (failed(clonedOperand))
|
||||
return failure();
|
||||
localMapper.map(operand, *clonedOperand);
|
||||
continue;
|
||||
}
|
||||
|
||||
localMapper.map(operand, operand);
|
||||
}
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
|
||||
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
|
||||
mapper.map(oldResult, newResult);
|
||||
|
||||
auto mapped = mapper.lookupOrNull(value);
|
||||
if (!mapped)
|
||||
return failure();
|
||||
return cast<Value>(mapped);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Returns true when a matrix-valued compute operand is ultimately backed by a
|
||||
/// weight-marked constant/view chain and can be promoted into weights.
|
||||
bool isWeightLikeComputeOperand(mlir::Value value);
|
||||
|
||||
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
|
||||
/// compute body while reusing already-materialized intermediate values.
|
||||
llvm::FailureOr<mlir::Value>
|
||||
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,299 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
|
||||
}
|
||||
|
||||
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||
return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
return llvm::all_of(op->getResultTypes(), [](Type type) {
|
||||
auto shapedType = dyn_cast<ShapedType>(type);
|
||||
return shapedType && shapedType.hasStaticShape();
|
||||
});
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
return failure();
|
||||
|
||||
int64_t rank = tensorType.getRank();
|
||||
if (static_cast<int64_t>(perms.size()) != rank)
|
||||
return failure();
|
||||
|
||||
llvm::SmallBitVector seen(rank);
|
||||
SmallVector<int64_t> transposedShape;
|
||||
transposedShape.reserve(rank);
|
||||
for (int64_t perm : perms) {
|
||||
if (perm < 0 || perm >= rank || seen.test(perm))
|
||||
return failure();
|
||||
seen.set(perm);
|
||||
transposedShape.push_back(tensorType.getShape()[perm]);
|
||||
}
|
||||
|
||||
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> transposedValues(originalValues.size());
|
||||
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
|
||||
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
|
||||
SmallVector<int64_t> originalIndices(rank);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < rank; ++dim) {
|
||||
originalIndices[dim] = remaining / originalStrides[dim];
|
||||
remaining %= originalStrides[dim];
|
||||
}
|
||||
|
||||
int64_t transposedLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < rank; ++dim)
|
||||
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
|
||||
|
||||
transposedValues[transposedLinearIndex] = value;
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(transposedType, transposedValues);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
|
||||
return DenseElementsAttr::get(resultType, values);
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
|
||||
tensor::ExtractSliceOp extractSliceOp) {
|
||||
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
|
||||
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
|
||||
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
|
||||
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|
||||
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
|
||||
return failure();
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
|
||||
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<Attribute> resultValues;
|
||||
resultValues.reserve(resultType.getNumElements());
|
||||
|
||||
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
|
||||
int64_t remaining = linearIndex;
|
||||
int64_t sourceLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
|
||||
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
|
||||
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
|
||||
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
|
||||
}
|
||||
resultValues.push_back(sourceValues[sourceLinearIndex]);
|
||||
}
|
||||
|
||||
return DenseElementsAttr::get(resultType, resultValues);
|
||||
}
|
||||
|
||||
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
|
||||
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
|
||||
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
|
||||
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
|
||||
// Rebuild dense attributes through view-only host-foldable chains so later
|
||||
// lowering stages can still recognize grouped/sliced constants.
|
||||
if (auto denseAttr = getDirectDenseConstantAttr(value))
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm;
|
||||
perm.reserve(transposeOp.getPermAttr().size());
|
||||
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
|
||||
perm.push_back(attr.getInt());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static std::optional<CompileTimeSource>
|
||||
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
if (!visited.insert(op).second)
|
||||
return {
|
||||
{op, chainLength}
|
||||
};
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return {
|
||||
{op, chainLength}
|
||||
};
|
||||
|
||||
chainLength += 1;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp)
|
||||
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return std::nullopt;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
|
||||
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp)
|
||||
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
std::optional<CompileTimeSource> res = {};
|
||||
for (auto operandValue : concatOp.getOperands()) {
|
||||
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
|
||||
if (!partialRes)
|
||||
return std::nullopt;
|
||||
|
||||
if (!res) {
|
||||
res = partialRes;
|
||||
continue;
|
||||
}
|
||||
if (res->chainLength < partialRes->chainLength)
|
||||
res = partialRes;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getCompileTimeSourceImpl(op, visited);
|
||||
}
|
||||
|
||||
bool isCompileTimeComputable(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getCompileTimeSourceImpl(definingOp, visited).has_value();
|
||||
}
|
||||
|
||||
bool isCompileTimeOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getCompileTimeSourceImpl(op, visited).has_value();
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostConstantDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CompileTimeSource {
|
||||
mlir::Operation* source;
|
||||
size_t chainLength;
|
||||
};
|
||||
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
|
||||
|
||||
bool isCompileTimeComputable(mlir::Value value);
|
||||
|
||||
bool isCompileTimeOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user