99 Commits

Author SHA1 Message Date
ilgeco 75fb70712f CodexWorkaround
Validate Operations / validate-operations (push) Has been cancelled
2026-06-08 11:33:36 +02:00
NiccoloN aec80529ca much faster MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 18:22:59 +02:00
ilgeco 8ddbbcecfa Added support for SliceOp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 17:36:51 +02:00
ilgeco 90c4339808 SpatialSubOp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 17:12:16 +02:00
ilgeco 08870de1a6 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-05 16:43:50 +02:00
NiccoloN a34ac223c0 fix remaining failing tests
Validate Operations / validate-operations (push) Has been cancelled
remove unsupported tests
2026-06-05 15:27:11 +02:00
NiccoloN 0fa10b4074 better Conv.cpp and fixed broken conv op validation test
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 13:35:27 +02:00
NiccoloN e166ff7e1d better AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 11:36:01 +02:00
ilgeco a70a8f77cf Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-05 10:20:09 +02:00
ilgeco 800c0c4316 Python peft and new summary report 2026-06-05 10:20:02 +02:00
NiccoloN 1e9e61f5a9 remove useless MaterializeHostConstantsPass.cpp and fix lowering before instead
Validate Operations / validate-operations (push) Has been cancelled
avoid spammy pim codegen diagnostics
2026-06-05 10:06:28 +02:00
ilgeco 27410207c4 New corner case test
Validate Operations / validate-operations (push) Has been cancelled
2026-06-04 16:00:48 +02:00
NiccoloN cbc9808229 more generalized MaterializeMergeSchedule.cpp for better memory usage after materialization
Validate Operations / validate-operations (push) Has been cancelled
2026-06-04 12:44:57 +02:00
NiccoloN 69021d56aa automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 19:43:56 +02:00
NiccoloN dc5edd032c Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-03 19:40:53 +02:00
NiccoloN e33f517221 faster scheduling: split batches into numCores tasks before scheduling instead of numLanes tasks 2026-06-03 19:40:34 +02:00
ilgeco f94b3d1020 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 18:15:33 +02:00
ilgeco 20cf40c9ba Memory Liveness 2026-06-03 18:15:30 +02:00
NiccoloN 37a59054a5 better loop compaction in MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 16:01:19 +02:00
ilgeco 2a8faf9c6b Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-03 13:49:42 +02:00
ilgeco 01b9d03fc6 Early warning on memory address 2026-06-03 13:49:39 +02:00
NiccoloN 501e6c76f3 better memory report
Validate Operations / validate-operations (push) Has been cancelled
capped vector allocations at u32::MAX in rust simulator
2026-06-03 13:48:42 +02:00
ilgeco 3c2667f11e Fix memory bug
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 12:59:58 +02:00
NiccoloN 0a5e73c3ea better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 12:26:31 +02:00
NiccoloN 636310d0cb add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
2026-06-01 16:49:06 +02:00
NiccoloN 356be6ccc2 uniquify constants produced by affine lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-06-01 10:52:25 +02:00
NiccoloN b678e55d3c compact memory contiguity with for loops
Validate Operations / validate-operations (push) Has been cancelled
2026-05-31 18:47:59 +02:00
NiccoloN ab63498f3f normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled
2026-05-30 16:37:28 +02:00
NiccoloN 7c3943bd06 Merge remote-tracking branch 'origin/refactorone' into refactorone
Validate Operations / validate-operations (push) Has been cancelled
# Conflicts:
#	src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp
2026-05-30 16:12:42 +02:00
NiccoloN c0238c0d06 fix high memory usage caused by MaterializeMergeSchedule.cpp with more robust code 2026-05-30 16:12:06 +02:00
NiccoloN ff36729140 centralize logic for materializing contiguous memory into bufferization
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 16:09:58 +02:00
NiccoloN cf93caecd5 centralize logic for materializing contiguous memory into bufferization
Validate Operations / validate-operations (push) Has been cancelled
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 15:54:24 +02:00
NiccoloN 2d5b03c08f automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 19:21:37 +02:00
NiccoloN a41f694cf0 batched matmul pattern
Validate Operations / validate-operations (push) Has been cancelled
add conv helpers
new validation tests for matmul
2026-05-29 19:09:48 +02:00
NiccoloN 8bb0babf1b finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled
use uniqued constant helpers everywhere
materialize transposed constants directly
2026-05-29 17:05:45 +02:00
ilgeco 819d8af0f7 Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 15:57:13 +02:00
ilgeco 832bd7f1f7 Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 13:23:31 +02:00
ilgeco 82b44a6387 New Onnx test gemm model 2026-05-29 11:41:30 +02:00
ilgeco 7fcc765d6e New Onnx Test model 2026-05-29 11:37:17 +02:00
ilgeco f34698a2b6 Validate new option for compile only
Validate Operations / validate-operations (push) Has been cancelled
2026-05-28 22:59:26 +02:00
ilgeco 1ab489fe0a Dynamic gemm/conv 2026-05-28 18:00:14 +02:00
ilgeco cbf7b235f1 pim-simulator now support usize addresses
Validate Operations / validate-operations (push) Has been cancelled
2026-05-28 17:03:19 +02:00
NiccoloN 00414dd1d9 add verification of communication invariants at the end of spatial
Validate Operations / validate-operations (push) Has been cancelled
remove dead logic
2026-05-27 19:17:48 +02:00
NiccoloN 783dffe553 fix scheduling cost model
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 17:14:19 +02:00
NiccoloN 874a2f53e6 automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:39:56 +02:00
NiccoloN 4bdaa57656 simplify affine maps to constants where possible
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:39:27 +02:00
NiccoloN 1a5d7d2a3f fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:15:10 +02:00
ilgeco 013ae0ac2a Update README and AGENTS
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 15:09:30 +02:00
ilgeco c6b02af7a9 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-27 14:32:51 +02:00
ilgeco d2048bd394 Add to gitignore 2026-05-27 14:32:47 +02:00
NiccoloN 158f0f0c54 update AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 14:32:04 +02:00
NiccoloN 532cac8246 commit AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 14:07:34 +02:00
NiccoloN d609e84054 teh only weight (WIP)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-26 18:42:14 +02:00
NiccoloN addfc8a86e remove other dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 21:22:08 +02:00
NiccoloN 0f240af271 cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 20:58:51 +02:00
ilgeco bdc4ca33f3 No extract no more
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 18:19:43 +02:00
ilgeco b79c333c6c Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-25 15:44:40 +02:00
ilgeco eea9261c7b Bye Bye DCP 2026-05-25 15:44:30 +02:00
NiccoloN e8a08f6dd0 faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 15:24:12 +02:00
NiccoloN 4855a2e105 add verification of static weights in spatial
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 12:00:42 +02:00
NiccoloN 3a7a832198 MaterializeMergeSchedule.cpp fix for yolo11_depth_18 2026-05-24 11:54:00 +02:00
NiccoloN 48ca6bd28d speed fix with a simple cache
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 10:52:28 +02:00
NiccoloN f595cc6ffd fix high memory usage in IR 2026-05-24 10:41:47 +02:00
NiccoloN c734f1b37e better MaterializeMergeSchedule.cpp that emits much more compact IR
Validate Operations / validate-operations (push) Has been cancelled
add support for other constant-time arith ops in codegen
2026-05-24 10:10:24 +02:00
NiccoloN b79ce8eeaa use affine dialect to express simple constant progressions
Validate Operations / validate-operations (push) Has been cancelled
run dce at the end of MaterializeMergeSchedule to get rid of unused constants
2026-05-23 14:25:34 +02:00
NiccoloN 76a37e198f better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops
Validate Operations / validate-operations (push) Has been cancelled
2026-05-23 11:17:36 +02:00
NiccoloN 7f3c7464b4 update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 22:16:19 +02:00
NiccoloN c77ffa9c56 better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
2026-05-22 21:52:28 +02:00
NiccoloN 495186503c fix cmake magic once again 2026-05-22 19:21:56 +02:00
NiccoloN 2c1da813b5 fix much stuff 2026-05-22 18:53:38 +02:00
NiccoloN 8337a11ce9 automatic code reformat 2026-05-22 15:23:48 +02:00
ilgeco d136136d22 Fix add of input in random order for compute_batch
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 15:21:02 +02:00
NiccoloN 074eb183c7 saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 07:27:54 +02:00
NiccoloN 43ed3914b8 better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 06:56:39 +02:00
ilgeco 6aaf1c0870 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:44:19 +02:00
ilgeco fe35b3ed43 Equivalent Class but broken 2026-05-21 14:43:59 +02:00
NiccoloN 90a9339686 better cmake to keep IDEs analyses happy
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:13:54 +02:00
NiccoloN a50e77ff38 refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-20 19:06:41 +02:00
NiccoloN f56c4159b5 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-19 15:01:26 +02:00
ilgeco 5637c861b4 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-19 15:00:11 +02:00
ilgeco 94157a8404 Very big timeout 2026-05-19 14:53:34 +02:00
ilgeco 68a3521978 Perft topological fix 2026-05-19 14:52:54 +02:00
NiccoloN a103ba328b remove dead logic 2026-05-19 12:23:01 +02:00
NiccoloN e263e05f56 remove dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 18:32:40 +02:00
ilgeco 34c29fdec4 Materialize modification
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:22:13 +02:00
ilgeco aa088e2ba5 Verify fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:20:40 +02:00
NiccoloN 2836e759ab remove useless file
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:51:03 +02:00
NiccoloN 8071ebab0b faster refactored merge pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:50:19 +02:00
NiccoloN f1602c0550 add peft scheduling
Validate Operations / validate-operations (push) Has been cancelled
better deadlock report by pim simulator
2026-05-18 12:09:27 +02:00
NiccoloN de0a2f4561 remove useless guard in gemm lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:22:13 +02:00
NiccoloN 1c4a5bde76 compact softmax op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:14:59 +02:00
NiccoloN 78242e2887 compact resize op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 17:36:12 +02:00
NiccoloN fe244d5aa1 new ops tests for matmul, grouped conv, concat and reshape
Validate Operations / validate-operations (push) Has been cancelled
related fixes
2026-05-14 15:54:06 +02:00
NiccoloN d09e76c8f9 fix matmul rewriting/lowering
Validate Operations / validate-operations (push) Has been cancelled
fix reshape lowering
add support for grouped-convolution lowering
quieter verifier with capped error messages
2026-05-14 14:09:30 +02:00
NiccoloN c5e608fa5b replace greedy pattern rewrites with partial conversions
Validate Operations / validate-operations (push) Has been cancelled
better failure messages
2026-05-14 11:48:16 +02:00
ilgeco 43f3ccdd21 new yolo nodes with 100% more statics
Validate Operations / validate-operations (push) Has been cancelled
2026-05-14 10:47:31 +02:00
NiccoloN 8d95c604a6 automatic code formatting
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 21:51:19 +02:00
NiccoloN 55eda487dc use seed in validate.py for deterministic tests 2026-05-13 21:49:36 +02:00
NiccoloN 061139aefb fix wrong send/receive reordering in post dcp merge instructions compaction 2026-05-13 21:48:49 +02:00
431 changed files with 23379 additions and 12999 deletions
+2 -5
View File
@@ -4,14 +4,11 @@
.claude .claude
.codex .codex
AGENTS.md
CMakeUserPresets.json CMakeUserPresets.json
build build_*
build_release
cmake-build-debug
cmake-build-release
compile.sh compile.sh
pimcomp_utils/*
**/__* **/__*
+210
View File
@@ -0,0 +1,210 @@
* Always read the full README.md before doing anything
* Build commands:
* `cmake --build ./build_release`
* `cmake --build ./build_debug`
* Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache
* Always try the release build first before building with the debug version
* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
# Core engineering philosophy
* Clean architecture matters as much as making the immediate test pass
* Prefer fixes that preserve clear ownership boundaries, explicit invariants, and simple dataflow
* Do not stack compensating fixes on top of earlier mistakes. If the current approach is becoming messy, stop and explain why
* A correct fix should usually make the responsible producer, resolver, verifier, or lowering own the behavior directly
* Avoid late repair passes, defensive cleanup, or broad rewrites when a cleaner owner-side fix is possible
* Do not hide an upstream modeling bug by normalizing it later in the pipeline. Fix the producer when the producer owns the invariant
* Prefer patterns/rewrites for local IR canonicalization. Use module walks only when pass-level structural analysis genuinely requires them
* Prefer compact, structured designs over long case-by-case implementations
# Think before coding
* State assumptions explicitly before implementing when they affect the design
* If multiple interpretations exist, present them instead of silently choosing one
* If a simpler approach exists, say so and prefer it unless there is a clear reason not to
* If something is unclear, stop, name what is confusing, and ask
* If the requested or obvious approach would make the architecture worse, push back and propose a cleaner alternative
# Code changes
* Keep changes minimal and localized to the relevant parts of the code
* Preserve the existing naming conventions and coding style used in the surrounding code
* Keep code easy to read, well organized, and suitable for future extensibility
* A function must not exceed roughly 200/250 lines. If a change pushes a function beyond that, extract focused helpers
* Prefer clear naming and structure over comments. Add comments only when they materially improve clarity
* Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change
* Avoid duplicate ad-hoc logic. If the same concept appears in multiple places, consider whether it deserves a shared helper/API
* When adding a helper or API, ask:
* Could this be useful to another component now
* Is another component already implementing the same idea differently
* Is this likely to be needed by a future adjacent component
* What is the narrowest useful abstraction
* What is the correct ownership level for this API
* If a shared API is justified, place it at the lowest clean layer that can be used by all relevant consumers without creating dependency cycles or leaking policy across layers
* If an existing component should use a newly introduced shared API, refactor that component in the same patch when doing so is directly related and reduces duplication
* Do not create broad frameworks just because a helper might someday be useful. Shared APIs should encode a real reusable concept, not speculative generality
* If the reusable abstraction is plausible but not clearly needed yet, keep the code local and mention the possible future extraction separately
# Avoid case-listing designs
* Avoid solving problems with large chains of `if`/`else`, switches, or repeated special cases that enumerate every possible situation
* Long case listings tend to overfit the current tests, grow the codebase, and hide the underlying abstraction
* When you see a growing list of special cases, stop and look for the shared concept, data model, interface, or normalization step that would make the cases collapse
* Prefer table-driven logic, traits/interfaces, small reusable predicates, structured dispatch, or producer-side normalization when they express the invariant more directly
* A few explicit cases are fine when the domain is genuinely small and closed
* If the list is likely to grow, refactor toward a cleaner and more compact design instead of adding another branch
* When keeping a case list is the pragmatic choice, explain why the domain is closed or why a broader abstraction would be premature
# Ownership and invariants
Before implementing, identify the owner of the behavior:
* A producer should emit IR/data that satisfies the contract of the next stage
* A lowering should make representation changes explicit and semantically correct
* A resolver should resolve existing structure without silently changing semantics
* A verifier should reject invalid states with bounded, actionable diagnostics
* Codegen should assume verified invariants and fail clearly if they are violated
When fixing a bug:
* State the invariant that was violated
* State which component should own that invariant
* Fix that component directly
* Avoid fixes that merely mask the violation later in the pipeline
* Add or preserve verification if the invariant is important enough to regress
# Refactor and API policy
You may propose or implement a refactor when:
* the local fix would duplicate logic
* the local fix would violate a layer boundary
* the bug exists because responsibility is assigned to the wrong component
* multiple components already implement ad-hoc variants of the same concept
* a shared helper/API would make the code smaller, clearer, and easier to maintain
* existing callers can be migrated cleanly without broad churn
* the current implementation is turning into a long list of special cases instead of a structured solution
When proposing or implementing a refactor:
* Explain what responsibility is being moved or shared
* Justify why the new location is the right ownership level
* Keep the API narrow and named after the concept or invariant it represents
* Migrate directly related existing users when that improves compactness and consistency
* Separate changes required for correctness from optional cleanup
* Avoid unrelated renames, formatting changes, or module moves
* Do not expand a justified refactor beyond directly related callers
Do not refactor when:
* the issue is truly local and a local fix is clearer
* the abstraction would have only one user and no clear adjacent use
* the abstraction would mix policies from different layers
* the refactor would affect unrelated behavior
* the refactor is mainly aesthetic
# Working style
* Infer style and conventions from the existing code before introducing new patterns
* When several implementation options are possible, prefer the simplest one that fits the current architecture and minimizes churn
* Push back when the requested or obvious fix would make the architecture worse
* If a cleaner fix requires a small refactor or shared helper/API, propose it explicitly and justify it
* Avoid broad refactors unless explicitly requested or clearly necessary for correctness and maintainability
* When tests fail, bucket failures by likely root cause and separate patch-related failures from pre-existing or out-of-scope failures
# Simplicity first
* Minimum code that solves the problem cleanly. Nothing speculative
* No features beyond what was asked
* No error handling for impossible scenarios
* If you write 200 lines and it could be 50, rewrite it
* Ask: “Would a senior engineer say this is overcomplicated?” If yes, simplify
* Prefer direct, explicit code over generic machinery unless the generic machinery clearly reduces duplication and preserves boundaries
# Fallbacks and defaults
* Avoid silent fallback behavior when the semantic category is unknown
* Do not treat “unknown” as “safe” unless the codebase already defines that convention
* If a value cannot be classified, either preserve the existing behavior deliberately or fail with a clear diagnostic
* When adding a fallback, state why it is semantically valid and what invariant makes it safe
# Surgical changes
* Touch only what you must
* Clean up only the mess introduced by your own change
* Do not “improve” adjacent code, comments, or formatting
* Match existing style, even if you would personally do it differently
* If you notice unrelated dead code, bad abstractions, or fragile design, mention it separately. Do not delete or rewrite it unless asked
* When your changes create orphans, remove imports, variables, functions, or files made unused by your change
* Every changed line should trace directly to the requested fix, a required cleanup, or a justified reuse/refactor decision
# Diagnostics and verification
* Use existing bounded diagnostic mechanisms for pass-level verification or codegen failures
* Do not emit unbounded repeated diagnostics from loops or parallel workers
* Diagnostics should identify the violated invariant and the relevant value/op when useful
* Verifiers should reject invalid states, not repair them
* Codegen should not compensate for invalid IR/data unless codegen is the owner of that invariant
* Do not make failing tests pass by weakening verifiers, assertions, or diagnostics unless the check itself is proven wrong
* If a check is too strict, explain the valid case it rejects and update the invariant accordingly
* Prefer fixing invalid IR/data producers over relaxing consumers
* If adding diagnostics only for debugging, remove them or cap them before finalizing
# Temporary debugging code
* Temporary diagnostics, dumps, assertions, and debug-only helpers must be removed or intentionally converted into bounded permanent diagnostics before finalizing
* If debug instrumentation remains, explain why it is useful as permanent infrastructure
* Do not leave noisy validation output behind
# Performance awareness
* Avoid algorithmic regressions in compiler passes, especially repeated full-module walks, repeated expensive analyses, or per-op recomputation inside nested loops
* If a change adds a walk, cache, analysis, or structural traversal, justify why it is needed
* For hot paths, prefer preserving existing asymptotic behavior unless a better structure is part of the requested change
* If performance may change, mention the expected impact and suggest a targeted timing check
# Goal-driven execution
For multi-step tasks, state a brief plan:
1. [Step] → verify: [check]
2. [Step] → verify: [check]
3. [Step] → verify: [check]
Define success criteria before implementing:
* For bug fixes, success means reproducing or identifying the failure, fixing the responsible owner, and verifying the targeted case
* For refactors, success means preserving behavior while making ownership, reuse, or structure cleaner
* For validation changes, success means checking both valid and invalid cases when applicable
Transform tasks into verifiable goals:
* “Fix the bug” → identify the invariant, reproduce the failure, fix the owner, verify the targeted case
* “Add validation” → write or identify tests for invalid inputs, then make them pass/fail as expected
* “Refactor X” → preserve behavior before and after, then run relevant tests
# Final self-review
Before reporting completion, check:
* Did I fix the owner of the invariant rather than masking the issue downstream
* Did I avoid broad case lists and ad-hoc special handling
* Did I introduce a helper/API only at the right ownership level
* Did I migrate directly related duplicate logic when doing so improves compactness
* Did I avoid weakening verifiers or assertions unnecessarily
* Did I remove temporary debugging code or make it bounded and intentional
* Did I avoid unrelated formatting, renames, or cleanup
* Did I consider performance impact for added walks, analyses, caches, or repeated computations
* Did I run the required build/test commands
* Did I clearly report remaining failures or risks
When reporting back:
* Say what changed
* Say what was verified
* Say what remains
* When showing code in chat, make it easy to copy-paste into the codebase
* Keep outputs focused on the changed parts
* List bad practices, fragile assumptions, or cleaner alternatives separately
* If a change is intentionally pragmatic rather than architecturally ideal, say so and explain the tradeoff
+92 -24
View File
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor) project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir # Materialize a CMake shim directory
function(raptor_ensure_symlink link_path target_path) function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
get_filename_component(link_parent "${link_path}" DIRECTORY) get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
if(NOT EXISTS "${link_parent}") if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
message(FATAL_ERROR "Directory not found: ${link_parent}") message(FATAL_ERROR
endif() "External CMake source directory not found or missing CMakeLists.txt:\n"
" ${real_external_source_dir}"
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
SYMBOLIC
) )
endif() endif ()
if (IS_SYMLINK "${shim_dir}")
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
file(REMOVE "${shim_dir}")
endif ()
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
endif ()
file(MAKE_DIRECTORY "${shim_dir}")
set(shim_file "${shim_dir}/CMakeLists.txt")
set(shim_contents
"get_filename_component(raptor_external_source_dir
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
REALPATH
)
add_subdirectory(
\"\${raptor_external_source_dir}\"
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
)
if (DEFINED PIM_ENABLED)
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
endif ()
"
)
if (EXISTS "${shim_file}")
file(READ "${shim_file}" old_contents)
else ()
set(old_contents "")
endif ()
if (NOT old_contents STREQUAL shim_contents)
file(WRITE "${shim_file}" "${shim_contents}")
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
else ()
message(STATUS "CMake shim already up to date for ${description}")
endif ()
# Mirror the external tree's first-level entries into the shim directory
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
foreach (child IN LISTS children)
if (child STREQUAL "CMakeLists.txt")
continue()
endif ()
set(real_child "${real_external_source_dir}/${child}")
set(shim_child "${shim_dir}/${child}")
if (IS_SYMLINK "${shim_child}")
file(READ_SYMLINK "${shim_child}" existing_link_target)
if (existing_link_target STREQUAL real_child)
continue()
endif ()
file(REMOVE_RECURSE "${shim_child}")
elseif (EXISTS "${shim_child}")
# Do not delete real files/directories. This protects the generated shim.
continue()
endif ()
file(CREATE_LINK
"${real_child}"
"${shim_child}"
SYMBOLIC
)
endforeach ()
endfunction() endfunction()
raptor_ensure_symlink( raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
"PIM accelerator"
) )
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM" raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM" "${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
"PIM accelerator tests"
) )
# Patch onnx-mlir sources for PIM accelerator support. # Patch onnx-mlir sources for PIM accelerator support.
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
# Already applied replacement text is present # Already applied replacement text is present
string(FIND "${contents}" "${replacement}" already_applied_pos) string(FIND "${contents}" "${replacement}" already_applied_pos)
if(NOT already_applied_pos EQUAL -1) if (NOT already_applied_pos EQUAL -1)
message(STATUS "Patch already applied: ${description}") message(STATUS "Patch already applied: ${description}")
return() return()
endif() endif ()
# Anchor must exist for the patch to be applicable # Anchor must exist for the patch to be applicable
string(FIND "${contents}" "${anchor}" anchor_pos) string(FIND "${contents}" "${anchor}" anchor_pos)
if(anchor_pos EQUAL -1) if (anchor_pos EQUAL -1)
message(FATAL_ERROR message(FATAL_ERROR
"Patch anchor not found onnx-mlir may have changed.\n" "Patch anchor not found onnx-mlir may have changed.\n"
" Patch : ${description}\n" " Patch : ${description}\n"
" File : ${file_path}\n" " File : ${file_path}\n"
" Anchor: ${anchor}" " Anchor: ${anchor}"
) )
endif() endif ()
string(REPLACE "${anchor}" "${replacement}" patched "${contents}") string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
file(WRITE "${file_path}" "${patched}") file(WRITE "${file_path}" "${patched}")
+252 -154
View File
@@ -1,223 +1,321 @@
# Raptor # Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format) Raptor is a domain-specific MLIR compiler for neural networks in ONNX format,
targeting in-memory computing / processing-in-memory (PIM) architectures. targeting in-memory computing / processing-in-memory (PIM) architectures. It
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to extends ONNX-MLIR with a PIM accelerator and progressively lowers ONNX-MLIR
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator). through custom MLIR dialects to simulator artifacts.
The current target is the PIM simulator stack under `backend-simulators/pim`.
Raptor emits binary per-core `.pim` instruction files by default, plus
`memory.bin`, `config.json`, and weight binaries. It can also emit per-core JSON
instruction files with `--pim-emit-json`.
## Overview ## Overview
PIM architectures perform most of the computation directly in memory. PIM architectures perform most computation directly in memory. The supported
Raptor's first supported target is `pimsim-nn`, which simulates a chip with: target models a chip with:
- a shared host memory, - shared host memory,
- a number of cores that do most of the computation directly in their memory - multiple PIM cores,
(vector ops, vmm/mvm on ReRAM crossbars), - ReRAM crossbars for vector-matrix / matrix-vector work,
- no branching instructions (branchless architecture) and no hardware loop - explicit communication between cores,
support — any repeated work (e.g. convolutions) must be unrolled into - no hardware branch or loop support in emitted simulator code.
explicit per-iteration instructions.
Because of this, the amount of emitted instructions explodes quickly and the Because repeated work such as convolutions is eventually made explicit, emitted
compiler must optimize aggressively at every stage to keep compilation instruction counts can grow quickly. Most compiler work therefore focuses on
tractable. lowering, scheduling, memory layout, and code-generation optimizations.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
### Targets and simulators ### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for - `backend-simulators/pim/pim-simulator` is the in-tree Rust functional
**performance** estimates (latency, energy), but does not functionally execute simulator used by validation. It reads Raptor's `pim/` artifact directory and
the JSON code it consumes. To validate the numerical correctness of the JSON compares simulator output against native ONNX-MLIR execution.
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use - `backend-simulators/pim/pimsim-nn` is the performance simulator submodule.
a Rust simulator we maintain in-tree at The helper scripts in `pimcomp_utils/` are for comparison with PIMCOMP-NN and
`backend-simulators/pim/pim-simulator`. contain local paths; treat them as local utilities, not portable workflows.
## Compilation pipeline ## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`. The PIM sources live under `src/PIM` and tests under `test/PIM`. CMake exposes
When working on this codebase, most changes should stay confined to those them to ONNX-MLIR through generated shim directories under
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for `onnx-mlir/src/Accelerators/PIM` and `onnx-mlir/test/accelerators/PIM`.
framework-level details).
High-level lowering flow: High-level lowering flow:
``` ```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON ONNX-MLIR -> Spatial -> Pim (tensor) -> Pim (bufferized) -> PIM artifacts
``` ```
1. **ONNX Spatial** (`src/PIM/Conversion/ONNXToSpatial`). 1. **ONNX -> Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`). Lowers supported ONNX ops into the `spat` dialect
Spatial models a high-level spatial in-memory accelerator: vmm/mvm (`src/PIM/Dialect/Spatial`). Conversion patterns are split by op family under
operations are accelerated by storing a constant RHS matrix into a `Patterns/{Math,NN,Tensor}` and currently cover Conv, Gemm, MatMul,
crossbar. Crossbars cannot be re-programmed during execution, have a elementwise Add/Mul/Div, ReduceMean, pooling, Relu, Sigmoid, Softmax,
limited fixed size, and there is a limited number of them per core. Concat, Gather, Reshape, Resize, and Split.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`). 2. **Merge compute nodes**
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
materializes PIM cores (`pim.core`), inter-core communication Builds a compute graph, schedules it with the PEFT scheduler, and materializes
(`pim.send` / `pim.receive`), halts, and crossbar-level operations. the merge schedule into Spatial IR. Supporting scheduling code lives under
`MergeComputeNodes/Scheduling`.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`). 3. **Spatial -> Pim** (`src/PIM/Conversion/SpatialToPim`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original Lowers Spatial operations to the `pim` dialect (`src/PIM/Dialect/Pim`),
scheduling paper by Kwok & Ahmad, including `pim.core`, `pim.core_batch`, communication, tensor packing, global
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf)) tensor materialization, and return-path normalization.
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`). 4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the Converts tensor-semantics PIM IR into memref-semantics PIM IR using MLIR's
standard MLIR `BufferizableOpInterface` machinery bufferization interfaces.
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`). 5. **Static memory coalescing**
Conservatively reuses same-typed local memref allocations inside PIM cores (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
after bufferization and before code generation. Reuses compatible local memref allocations inside PIM cores before codegen.
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`): 6. **PIM code generation** (`src/PIM/Pass/PimCodegen` and
- `HostConstantFolding` — folds host-side constants. `src/PIM/Compiler`).
- `MaterializeHostConstantsPass` materializes the remaining host Folds host constants, materializes remaining host constants, verifies PIM IR,
constants for emission. emits `.pim` core files, writes weights, and writes `memory.bin` /
- `VerificationPass` — checks invariants before emission. `config.json`.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
Supporting pieces: Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count, - `src/PIM/Common` - shared IR, filesystem, diagnostics, reports, and utility
core count, DCP window, experimental conv impl, concat error handling, …) helpers.
and `PimCodeGen` entry points. - `src/PIM/Compiler` - PIM compiler options, memory/address planning, binary
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`). instruction format, artifact writing, weight emission, and codegen entry
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`) points.
and the `PIMPasses.h` registry used by `PimAccelerator`. - `src/PIM/Conversion/SpatialToGraphviz` - optional Spatial graphviz conversion
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers pass.
dialects, passes, and plugs Raptor into the ONNX-MLIR driver. - `src/PIM/Pass` - pass registration and auxiliary passes.
- `src/PIM/PimAccelerator.{cpp,hpp}` - ONNX-MLIR accelerator entry point.
## Key compiler options ## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM: Pass these to `onnx-mlir` when compiling for PIM:
- `--maccel=PIM` select the PIM accelerator. - `--maccel=PIM` - select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen` - `--EmitSpatial`, `--EmitPim`, `--EmitPimBufferized`,
stop the pipeline at the requested stage (default: `EmitPimCodegen`). `--EmitPimCodegen` - stop the PIM pipeline at the requested stage. The PIM
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and default is `--EmitPimCodegen`.
run only the codegen tail. - `--core-count=<N>` - required positive core count for PIM compilation.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and - `--crossbar-size=<N>` - crossbar width/height. Default in code is `2`.
per-core count. - `--crossbar-count=<N>` - crossbars per core. Default in code is `256`.
- `--core-count=<N>` — number of cores (`-1` picks the minimum). - `--pim-merge-scheduler=peft` - merge scheduler. `peft` is the only accepted
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy). value in the current code.
- `--use-experimental-conv-impl` — alternative convolution lowering. - `--pim-only-codegen` - assume input is already bufferized PIM IR and only run
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`. the codegen tail.
- `--pim-emit-json` - also emit `core_*.json` instruction files alongside
`core_*.pim`.
- `--use-experimental-conv-impl` - use the alternate convolution lowering.
- `--ignore-concat-error` - soft-fail a ConcatOp corner case.
Example:
```bash
./build_release/Release/bin/onnx-mlir model.onnx -o /tmp/raptor/model \
--maccel=PIM --EmitPimCodegen \
--crossbar-size=2048 --crossbar-count=256 --core-count=1000
```
This writes PIM artifacts under `/tmp/raptor/pim/`.
## Validation ## Validation
Functional validation lives in `validation/` and drives the Rust Functional validation lives in `validation/`. It compiles ONNX models, builds a
`pim-simulator` to compare Raptor's output against a reference. native ONNX-MLIR reference runner, generates random inputs, runs Raptor, runs
the Rust PIM simulator, and compares outputs.
Per-operation validation (from `validation/`): Python dependencies used by the validation scripts are `numpy`, `onnx`, and
`colorama`. The simulator requires the Rust toolchain.
``` Per-operation validation from the repository root:
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ ```bash
--onnx-include-dir ../onnx-mlir/include python3 validation/validate.py \
--raptor-path build_release/Release/bin/onnx-mlir \
--onnx-include-dir onnx-mlir/include \
--core-count 1000
``` ```
End-to-end network validation (example: first 4 layers of YOLOv11n): Validate one network or a subset by pointing `--operations-dir` at any directory
containing `.onnx` files:
``` ```bash
validate.py \ python3 validation/validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \ --raptor-path build_release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \ --onnx-include-dir onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \ --operations-dir validation/networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000 --crossbar-size 2048 --crossbar-count 256 --core-count 1000
``` ```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`. Useful validation options:
Available operations under `validation/operations/`: `add`, `conv`, `div`, - `--simulator-dir <path>` - override the auto-detected
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`, `backend-simulators/pim/pim-simulator` path.
`sigmoid`, `softmax`, `split`. - `--threshold <float>` - maximum allowed per-element output difference.
- `--seed <int>` - RNG seed for generated inputs.
- `--command-timeout-seconds <float>` - timeout for compiler, runner, and
simulator subprocesses.
- `--verbose` - print subprocess logs and average PIM pass timings.
- `--clean` - remove generated validation artifacts and exit.
## Rebuilding Each validation run writes artifacts in the model workspace, for example under
`validation/operations/gemm/small/`:
- `inputs/` - generated input CSV files.
- `outputs/` - native ONNX-MLIR reference outputs.
- `raptor/` - compiler artifacts, including `*.onnx.mlir`, dialect dumps under
`dialects/`, reports under `reports/`, and final PIM artifacts under `pim/`.
- `runner/` - generated reference runner source, build tree, and shared library.
- `simulation/out.bin` - raw simulator output used for comparison.
Release build (fast): The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
available.
``` To rerun the simulator manually with tracing after validation has produced a
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30 `raptor/pim/` directory:
```bash
cd backend-simulators/pim/pim-simulator
cargo run --no-default-features --features tracing --release \
--package pim-simulator --bin pim-simulator -- \
-f /path/to/workspace/raptor/pim \
-o /path/to/workspace/simulation/out.bin \
-d <addr0>,<size0>,<addr1>,<size1>,...
``` ```
A slower debug build is also available — configure it the same way but with With `--features tracing`, the simulator writes per-core traces as
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below). `TraceCore0`, `TraceCore1`, ... next to `out.bin`. The validator normally
computes the `-d` ranges from `raptor/pim/config.json` and model output shapes.
Available validation networks under `validation/networks/`: `vgg16`,
`yolo11n`, `yolo11nv2`.
Available operation suites under `validation/operations/`: `add`, `concat`,
`conv`, `div`, `gather`, `gemm`, `gemv`, `matmul`, `mul`, `pool`,
`reduce_mean`, `relu`, `reshape`, `resize`, `sigmoid`, `softmax`, `split`.
Generated operation tests can be regenerated with:
```bash
python3 validation/operations/gen_tests.py
```
## Build ## Build
Initialize submodules first:
```bash
git submodule update --init --recursive
```
The project follows ONNX-MLIR's build requirements. The CI workflow documents
the currently used versions and setup:
- CMake 4.3.0 in CI,
- LLVM/MLIR checked out under `onnx-mlir/llvm-project`,
- Protobuf `v34.0`,
- Rust stable for `pim-simulator`,
- Python packages `numpy`, `onnx`, `colorama` for validation.
### Protobuf ### Protobuf
Use the following commands to install protobuf: Install Protobuf if your system does not already provide a compatible version:
```
```bash
git clone --depth 1 --branch v34.0 https://github.com/protocolbuffers/protobuf git clone --depth 1 --branch v34.0 https://github.com/protocolbuffers/protobuf
cd protobuf cmake -S protobuf -B protobuf/build -G Ninja \
mkdir build -DCMAKE_BUILD_TYPE=Release \
cd build -Dprotobuf_BUILD_TESTS=OFF
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release cmake --build protobuf/build
ninja sudo cmake --install protobuf/build
sudo ninja install
``` ```
You can now remove the protobuf repo directory with: You can then remove the temporary checkout:
```
cd ../.. ```bash
rm -rf protobuf rm -rf protobuf
``` ```
### Mlir ### MLIR
Follow the first part of instructions [here](onnx-mlir/docs/BuildOnLinuxOSX.md) to build mlir. Follow the ONNX-MLIR instructions in
`onnx-mlir/docs/BuildOnLinuxOSX.md` to build LLVM/MLIR. The local Raptor build
expects `MLIR_DIR` to point at the MLIR CMake package, for example:
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor ```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
Moreover, if compiling with build type debug, it is also suggested to use
mold as linker (you will need to install it if you don't have it already)
to reduce memory usage during linking. You can use it by setting the options:
```
-DLLVM_USE_LINKER=mold
``` ```
If your LLVM build directory is named `build` instead of `build_release`, adjust
the path accordingly.
### Raptor ### Raptor
Use the following commands to build Raptor. Configure a release build:
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor. ```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage, cmake -S . -B build_release -G Ninja \
setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
```
```
git submodule update --init --recursive
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
mkdir build && cd build
cmake .. -G Ninja \
-DCMAKE_BUILD_TYPE=Release \ -DCMAKE_BUILD_TYPE=Release \
-DONNX_MLIR_ACCELERATORS=PIM \ -DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR} -DMLIR_DIR=${MLIR_DIR}
cmake --build .
``` ```
If the build fails because of protobuf missing uint definitions, Configure a debug build similarly:
just patch the problematic files by adding ```#include <cstdint>``` to their includes.
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_debug/lib/cmake/mlir
cmake -S . -B build_debug -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
```
For debug development, using `mold` can reduce link time and memory use:
```bash
cmake -S . -B build_debug -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR} \
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
```
Build the compiler with CMake:
```bash
cmake --build ./build_release
cmake --build ./build_debug
```
Do not invoke `ninja` directly for this project; use `cmake --build` so CMake's
configuration and generated shims stay consistent.
If a build fails because Protobuf headers are missing fixed-width integer
definitions, patch the affected Protobuf-generated files by adding
`#include <cstdint>`.
## Tests
The Rust simulator has its own tests:
```bash
cd backend-simulators/pim/pim-simulator
cargo test
```
## Repository Layout
- `src/PIM/` - PIM accelerator implementation.
- `test/PIM/` - PIM C++ unit tests.
- `validation/` - functional validation scripts, ONNX operation tests, network
slices, and pimsim config generation.
- `backend-simulators/pim/pim-simulator/` - in-tree Rust functional simulator.
- `backend-simulators/pim/pimsim-nn/` - performance simulator submodule.
- `pimcomp_utils/` - local comparison helpers for PIMCOMP-NN.
- `.github/actions/` and `.github/workflows/validate_operations.yml` - CI setup
for MLIR/Protobuf caching, building Raptor, and validation.
@@ -43,7 +43,7 @@ struct Args {
/// Comma separated list of (address,size) for memory output dump /// Comma separated list of (address,size) for memory output dump
#[arg(short, long, value_delimiter = ',', num_args = 1.., value_name = "ADDR,SIZE")] #[arg(short, long, value_delimiter = ',', num_args = 1.., value_name = "ADDR,SIZE")]
dump: Vec<i32>, dump: Vec<usize>,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
@@ -67,7 +67,7 @@ fn main() -> Result<()> {
.lock() .lock()
.unwrap() .unwrap()
.init(executor.cpu().num_core(), args.output.clone()); .init(executor.cpu().num_core(), args.output.clone());
executor.execute(); executor.execute()?;
dump_memory(executor, &args)?; dump_memory(executor, &args)?;
Ok(()) Ok(())
} }
@@ -168,7 +168,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
} }
fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> { fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
let dumps: Vec<(i32, i32)> = args let dumps: Vec<(usize, usize)> = args
.dump .dump
.chunks_exact(2) .chunks_exact(2)
.map(|chunk| (chunk[0], chunk[1])) .map(|chunk| (chunk[0], chunk[1]))
@@ -1,3 +1,4 @@
use crate::utility::AddressArg;
use std::{collections::HashMap, fmt::Debug}; use std::{collections::HashMap, fmt::Debug};
use anyhow::{Context, Result, ensure}; use anyhow::{Context, Result, ensure};
@@ -9,6 +10,7 @@ use crate::{
pub mod crossbar; pub mod crossbar;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct CPU<'a> { pub struct CPU<'a> {
cores: Box<[Core<'a>]>, cores: Box<[Core<'a>]>,
@@ -91,30 +93,26 @@ impl<'a> Core<'a> {
self.memory.execute_load() self.memory.execute_load()
} }
pub fn execute_store<T>(&mut self, address: impl TryToUsize, element: &[T]) -> Result<()> pub fn execute_store<T>(&mut self, address: impl AddressArg, element: &[T]) -> Result<()>
where where
T: MemoryStorable, T: MemoryStorable,
{ {
let address = address.try_into().context("address can not be negative")?; let address = address.to_address_usize()?;
self.memory.execute_store(address, element) self.memory.execute_store(address, element)
} }
pub fn reserve_load( pub fn reserve_load(
&mut self, &mut self,
address: impl TryToUsize, address: impl AddressArg,
size: impl TryToUsize, size: impl TryToUsize,
) -> Result<&mut CoreMemory> { ) -> Result<&mut CoreMemory> {
let address = address.try_into().context("address can not be negative")?; let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?; let size = size.try_into().context("size can not be negative")?;
self.memory.reserve_load(address, size) self.memory.reserve_load(address, size)
} }
pub fn set_register(&mut self, index: impl TryToUsize, value: i32) { pub fn set_register(&mut self, index: impl TryToUsize, value: i32) {
let index = index.try_into().expect("index can not be negative"); let index = index.try_into().expect("index can not be negative");
assert!(
value >= 0,
"Register cannot be negative if happens remove this and go check where it's used as usize"
);
self.registers[index] = value; self.registers[index] = value;
} }
@@ -123,11 +121,11 @@ impl<'a> Core<'a> {
self.registers[index] self.registers[index]
} }
pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>> pub fn load<T>(&mut self, address: impl AddressArg, size: impl TryToUsize) -> Result<Vec<&[T]>>
where where
T: MemoryStorable, T: MemoryStorable,
{ {
let address = address.try_into().context("address can not be negative")?; let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?; let size = size.try_into().context("size can not be negative")?;
self.memory.load(address, size) self.memory.load(address, size)
} }
@@ -141,8 +139,8 @@ impl<'a> Core<'a> {
(memory, crossbars) (memory, crossbars)
} }
pub fn memset(&mut self, address: impl TryToUsize, size: impl TryToUsize, val: u8) -> Result<()> { pub fn memset(&mut self, address: impl AddressArg, size: impl TryToUsize, val: u8) -> Result<()> {
let address = address.try_into().context("address can not be negative")?; let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?; let size = size.try_into().context("size can not be negative")?;
self.memory.memset(address, size, val) self.memory.memset(address, size, val)
} }
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::fmt::Debug; use std::fmt::Debug;
use anyhow::{Context, Result, bail, ensure}; use anyhow::{Context, Result, bail, ensure};
@@ -86,7 +87,7 @@ where {
size, size,
}; };
if self.memory.len() < address + size { if self.memory.len() < address + size {
self.memory.resize((address + size) * 2, 0); self.memory.resize(min((address + size) * 2, u32::MAX as usize), 0);
} }
self.load_requests.push(load_request); self.load_requests.push(load_request);
Ok(self) Ok(self)
@@ -1,5 +1,6 @@
#![allow(unused)] #![allow(unused)]
use anyhow::{Result, bail};
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
time::{Duration, SystemTime}, time::{Duration, SystemTime},
@@ -87,6 +88,11 @@ pub struct Executable<'a> {
send_recv: SendRecv, send_recv: SendRecv,
} }
struct DeadlockInfo {
cycle: String,
states: String,
}
fn print_status(core_instructions: &[CoreInstructions]) { fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0; let mut tot_instructions = 0;
let mut progress = 0; let mut progress = 0;
@@ -118,7 +124,7 @@ impl<'a> Executable<'a> {
} }
} }
pub fn execute<'b>(&'b mut self) pub fn execute<'b>(&'b mut self) -> Result<()>
where where
'a: 'b, 'a: 'b,
{ {
@@ -153,7 +159,13 @@ impl<'a> Executable<'a> {
} }
if (now.elapsed().unwrap() > Duration::from_secs(5)) { if (now.elapsed().unwrap() > Duration::from_secs(5)) {
print_status(cores_instructions); print_status(cores_instructions);
check_cycle(cpu, cores_instructions, send_recv); if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
now = SystemTime::now(); now = SystemTime::now();
} }
} }
@@ -178,8 +190,23 @@ impl<'a> Executable<'a> {
} }
print_status(cores_instructions); print_status(cores_instructions);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
if cores_instructions
.iter()
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
{
bail!("Execution stalled with unfinished instructions");
}
#[cfg(feature = "profile_time")] #[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report(); TRACER.lock().unwrap().report();
Ok(())
} }
pub fn cpu(&self) -> &CPU<'a> { pub fn cpu(&self) -> &CPU<'a> {
@@ -201,12 +228,12 @@ impl<'a> Executable<'a> {
} }
} }
fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv: &mut SendRecv) { fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
enum CoreState { enum CoreState {
SendingTo(i32), SendingTo(i32, i32),
ReceivingFrom(i32), ReceivingFrom(i32, i32),
Working, Working,
Halted, Halted,
} }
@@ -223,9 +250,9 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
let (this_core, target_core) = data.get_core_immcore(); let (this_core, target_core) = data.get_core_immcore();
if isa_recv(functor_address) { if isa_recv(functor_address) {
states.insert(this_core, CoreState::ReceivingFrom(target_core)); states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
} else if isa_send(functor_address) { } else if isa_send(functor_address) {
states.insert(this_core, CoreState::SendingTo(target_core)); states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
} else { } else {
states.insert(this_core, CoreState::Working); states.insert(this_core, CoreState::Working);
} }
@@ -235,15 +262,15 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
for (&core_id, state) in states.iter() { for (&core_id, state) in states.iter() {
match state { match state {
CoreState::SendingTo(target_core) => { CoreState::SendingTo(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted); let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::ReceivingFrom(core_id) { if target_state != &CoreState::ReceivingFrom(core_id, *size) {
wait_for.insert(core_id, *target_core); wait_for.insert(core_id, *target_core);
} }
} }
CoreState::ReceivingFrom(target_core) => { CoreState::ReceivingFrom(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted); let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::SendingTo(core_id) { if target_state != &CoreState::SendingTo(core_id, *size) {
wait_for.insert(core_id, *target_core); wait_for.insert(core_id, *target_core);
} }
} }
@@ -272,18 +299,41 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
if in_path.contains(&waiting_for) { if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap(); let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..]; let cycle = &path[cycle_start..];
let format_core = |core: &i32| (core - 1).to_string();
let cycle_str = cycle let cycle_str = cycle
.iter() .iter()
.map(|c| c.to_string()) .map(format_core)
.collect::<Vec<_>>() .collect::<Vec<_>>()
.join(" -> "); .join(" -> ");
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for); let cycle = cycle
.iter()
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core - 1, size, target - 1)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
}
CoreState::Working => format!("core {} working", core - 1),
CoreState::Halted => format!("core {} halted", core - 1),
})
})
.collect::<Vec<_>>()
.join(", ");
println!("Fatal: Deadlock cycle detected: {}", cycle_msg); return Some(DeadlockInfo {
// bail!("Deadlock detected: {}", cycle_msg); cycle: cycle_msg,
break; // Stop tracing states: states_msg,
});
} }
// Hit a known branch that didn't result in a cycle // Hit a known branch that didn't result in a cycle
@@ -294,6 +344,7 @@ fn check_cycle(cpu: &mut CPU, cores_instructions: &[CoreInstructions], send_recv
current_core = waiting_for; current_core = waiting_for;
} }
} }
None
} }
fn handle_wait_sync<'a, 'b, 'c>( fn handle_wait_sync<'a, 'b, 'c>(
@@ -1,7 +1,45 @@
use anyhow::{Result,Context};
use std::{fmt::Debug, mem::transmute}; use std::{fmt::Debug, mem::transmute};
use crate::memory_manager::type_traits::TryToUsize;
pub trait AddressArg {
fn to_address_usize(self) -> Result<usize>;
}
impl AddressArg for usize {
fn to_address_usize(self) -> Result<usize> {
Ok(self)
}
}
impl AddressArg for u32 {
fn to_address_usize(self) -> Result<usize> {
Ok(self as usize)
}
}
impl AddressArg for u64 {
fn to_address_usize(self) -> Result<usize> {
usize::try_from(self).context("address does not fit in usize")
}
}
impl AddressArg for i32 {
fn to_address_usize(self) -> Result<usize> {
Ok(self as u32 as usize)
}
}
impl AddressArg for i64 {
fn to_address_usize(self) -> Result<usize> {
usize::try_from(self).context("address can not be negative")
}
}
fn address_to_usize(address: i32) -> usize {
address as u32 as usize
}
fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i32) -> usize{ fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i32) -> usize{
assert!(offset_select == 1 || offset_select == 2 || offset_select == 4 || offset_value == 0, "offset_select not a bit field"); assert!(offset_select == 1 || offset_select == 2 || offset_select == 4 || offset_value == 0, "offset_select not a bit field");
@@ -14,21 +52,21 @@ fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i
} }
pub fn add_offset_rd(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize pub fn add_offset_rd(address: i32, offset_select : i32, offset_value : i32) -> usize
{ {
let address = address.try_into().expect("address can not be negative"); let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 4) add_offset_impl(address, offset_select, offset_value, 4)
} }
pub fn add_offset_r1(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize pub fn add_offset_r1(address: i32, offset_select : i32, offset_value : i32) -> usize
{ {
let address = address.try_into().expect("address can not be negative"); let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 1) add_offset_impl(address, offset_select, offset_value, 1)
} }
pub fn add_offset_r2(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize pub fn add_offset_r2(address: i32, offset_select : i32, offset_value : i32) -> usize
{ {
let address = address.try_into().expect("address can not be negative"); let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 2) add_offset_impl(address, offset_select, offset_value, 2)
} }
Symlink
+1
View File
@@ -0,0 +1 @@
/home/ilgeco/Project/Raptor/build_debug/
+254
View File
@@ -0,0 +1,254 @@
diff --git a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
index 0b7e8cc..32964aa 100644
--- a/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
+++ b/src/PIM/Conversion/ONNXToSpatial/CMakeLists.txt
@@ -22,6 +22,7 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Gather.cpp
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
+ Patterns/Tensor/Slice.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
diff --git a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
index edf311e..c3d42f7 100644
--- a/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/ONNXToSpatialPass.cpp
@@ -138,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXGatherOp>();
target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXResizeOp>();
+ target.addIllegalOp<ONNXSliceOp>();
target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>();
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
index ffa0b1f..0a747e9 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.cpp
@@ -22,6 +22,7 @@ void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
+ populateSlicePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
index e58729e..c040536 100644
--- a/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns.hpp
@@ -29,6 +29,7 @@ void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
+void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
new file mode 100644
index 0000000..3f8867f
--- /dev/null
+++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Tensor/Slice.cpp
@@ -0,0 +1,200 @@
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "llvm/ADT/SmallVector.h"
+
+#include <algorithm>
+#include <optional>
+
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
+#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
+#include "src/Dialect/ONNX/ONNXOps.hpp"
+
+using namespace mlir;
+
+namespace onnx_mlir {
+namespace {
+
+static DenseElementsAttr getDenseConstantAttr(Value value) {
+ if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
+ return dyn_cast<DenseElementsAttr>(constantOp.getValue());
+ if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
+ return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
+ return nullptr;
+}
+
+static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
+ auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getDenseConstantAttr(value));
+ if (!denseAttr)
+ return failure();
+ return SmallVector<int64_t>(denseAttr.getValues<int64_t>().begin(), denseAttr.getValues<int64_t>().end());
+}
+
+static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
+
+static FailureOr<Value> buildSlice(Value data,
+ RankedTensorType dataType,
+ RankedTensorType resultType,
+ ArrayRef<int64_t> starts,
+ ArrayRef<int64_t> ends,
+ std::optional<ArrayRef<int64_t>> axes,
+ std::optional<ArrayRef<int64_t>> steps,
+ ConversionPatternRewriter& rewriter,
+ Location loc) {
+ int64_t rank = dataType.getRank();
+ if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
+ return failure();
+
+ if (starts.size() != ends.size())
+ return failure();
+ if (axes && axes->size() != starts.size())
+ return failure();
+ if (steps && steps->size() != starts.size())
+ return failure();
+
+ SmallVector<int64_t> normalizedAxes;
+ if (axes) {
+ SmallVector<bool> seenAxes(rank, false);
+ normalizedAxes.reserve(axes->size());
+ for (int64_t axis : *axes) {
+ auto normalizedAxis = normalizeAxisChecked(axis, rank);
+ if (failed(normalizedAxis))
+ return failure();
+ if (seenAxes[*normalizedAxis])
+ return failure();
+ seenAxes[*normalizedAxis] = true;
+ normalizedAxes.push_back(*normalizedAxis);
+ }
+ }
+ else {
+ if (starts.size() > static_cast<size_t>(rank))
+ return failure();
+ normalizedAxes.reserve(starts.size());
+ for (size_t i = 0; i < starts.size(); ++i)
+ normalizedAxes.push_back(static_cast<int64_t>(i));
+ }
+
+ SmallVector<int64_t> normalizedSteps;
+ if (steps)
+ normalizedSteps.assign(steps->begin(), steps->end());
+ else
+ normalizedSteps.assign(starts.size(), 1);
+
+ SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
+ SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
+ SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
+ SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
+
+ for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
+ int64_t step = normalizedSteps[sliceIndex];
+ if (step <= 0)
+ return failure();
+
+ int64_t dimSize = dataType.getShape()[axis];
+ int64_t start = starts[sliceIndex];
+ int64_t end = ends[sliceIndex];
+
+ if (start < 0)
+ start += dimSize;
+ if (end < 0)
+ end += dimSize;
+
+ start = std::clamp(start, int64_t {0}, dimSize);
+ end = std::clamp(end, int64_t {0}, dimSize);
+
+ int64_t extent = std::max(end - start, int64_t {0});
+ int64_t size = (extent + step - 1) / step;
+
+ offsets[axis] = rewriter.getIndexAttr(start);
+ sizes[axis] = rewriter.getIndexAttr(size);
+ strides[axis] = rewriter.getIndexAttr(step);
+ computedShape[axis] = size;
+ }
+
+ if (llvm::ArrayRef(computedShape) != resultType.getShape())
+ return failure();
+
+ return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
+}
+
+struct Slice final : OpConversionPattern<ONNXSliceOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
+ ONNXSliceOpAdaptor adaptor,
+ ConversionPatternRewriter& rewriter) const override {
+ auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
+ auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
+ if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape())
+ return failure();
+
+ auto starts = getConstantIntValues(adaptor.getStarts());
+ auto ends = getConstantIntValues(adaptor.getEnds());
+ if (failed(starts))
+ return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
+ if (failed(ends))
+ return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
+
+ std::optional<SmallVector<int64_t>> axes;
+ if (!isNoneValueLike(adaptor.getAxes())) {
+ auto parsedAxes = getConstantIntValues(adaptor.getAxes());
+ if (failed(parsedAxes))
+ return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
+ axes = std::move(*parsedAxes);
+ }
+
+ std::optional<SmallVector<int64_t>> steps;
+ if (!isNoneValueLike(adaptor.getSteps())) {
+ auto parsedSteps = getConstantIntValues(adaptor.getSteps());
+ if (failed(parsedSteps))
+ return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
+ steps = std::move(*parsedSteps);
+ if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
+ return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
+ }
+
+ ArrayRef<int64_t> startsRef = *starts;
+ ArrayRef<int64_t> endsRef = *ends;
+ std::optional<ArrayRef<int64_t>> axesRef = axes ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*axes))
+ : std::nullopt;
+ std::optional<ArrayRef<int64_t>> stepsRef = steps ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*steps))
+ : std::nullopt;
+
+ Location loc = sliceOp.getLoc();
+ auto tryBuildSlice = [&](Value data) {
+ return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
+ };
+
+ if (isCompileTimeComputable(adaptor.getData())) {
+ auto sliced = tryBuildSlice(adaptor.getData());
+ if (failed(sliced))
+ return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
+ rewriter.replaceOp(sliceOp, *sliced);
+ return success();
+ }
+
+ auto computeOp =
+ createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
+ auto sliced = tryBuildSlice(data);
+ if (failed(sliced))
+ return failure();
+ spatial::SpatYieldOp::create(rewriter, loc, *sliced);
+ return success();
+ });
+ if (failed(computeOp))
+ return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
+
+ rewriter.replaceOp(sliceOp, computeOp->getResults());
+ return success();
+ }
+};
+
+} // namespace
+
+void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<Slice>(ctx); }
+
+} // namespace onnx_mlir
+56 -1
View File
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT}) set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT}) set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
set(PIM_GENERATED_PATH_SHIM_TARGET "")
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
function(add_pim_generated_path_shim relative_path)
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
add_custom_command(
OUTPUT "${shim_file}"
DEPENDS "${real_file}"
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
VERBATIM
)
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
endfunction()
file(GLOB_RECURSE pim_generated_path_scan_sources
CONFIGURE_DEPENDS
"${PIM_SRC_ROOT}/*.cpp"
"${PIM_SRC_ROOT}/*.hpp"
)
set(pim_generated_path_shims)
foreach (source_file IN LISTS pim_generated_path_scan_sources)
file(READ "${source_file}" source_contents)
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
foreach (inc_match IN LISTS source_inc_matches)
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
list(APPEND pim_generated_path_shims "${relative_inc_path}")
endforeach ()
endforeach ()
list(REMOVE_DUPLICATES pim_generated_path_shims)
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
add_pim_generated_path_shim("${relative_inc_path}")
endforeach ()
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
endif ()
set(PIM_PUBLIC_INCLUDE_DIRS set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include ${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT} ${ONNX_MLIR_SRC_ROOT}
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
function(add_pim_library name) function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN}) add_onnx_mlir_library(${name} STATIC ${ARGN})
if (PIM_GENERATED_PATH_SHIM_TARGET)
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
endif ()
endfunction() endfunction()
add_subdirectory(Dialect) add_subdirectory(Dialect)
@@ -68,6 +121,8 @@ add_pim_library(OMPIMAccel
OMSpatialToPim OMSpatialToPim
OMPimCommon OMPimCommon
OMPimBufferization OMPimBufferization
OMPimStaticMemoryCoalescing OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimVerification
MLIRTensorInferTypeOpInterfaceImpl MLIRTensorInferTypeOpInterfaceImpl
) )
+7
View File
@@ -1,10 +1,15 @@
add_pim_library(OMPimCommon add_pim_library(OMPimCommon
IR/AffineUtils.cpp
IR/AddressAnalysis.cpp IR/AddressAnalysis.cpp
IR/BatchCoreUtils.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp IR/EntryPointUtils.cpp
IR/LoopUtils.cpp
IR/ShapeUtils.cpp IR/ShapeUtils.cpp
IR/SubviewUtils.cpp IR/SubviewUtils.cpp
IR/WeightUtils.cpp IR/WeightUtils.cpp
Support/CheckedArithmetic.cpp
Support/DebugDump.cpp Support/DebugDump.cpp
Support/Diagnostics.cpp Support/Diagnostics.cpp
Support/FileSystemUtils.cpp Support/FileSystemUtils.cpp
@@ -16,6 +21,8 @@ add_pim_library(OMPimCommon
${PIM_PUBLIC_INCLUDE_DIRS} ${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
onnx onnx
SpatialOps SpatialOps
PimOps PimOps
+555 -7
View File
@@ -1,7 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include <limits>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -28,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
return value; return value;
} }
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge); value = resolveAlias(value, knowledge);
@@ -55,6 +67,288 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
} }
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge); llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
return mlir::failure();
return strides;
}
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) {
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
}
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
return iter->second;
return mlir::failure();
}
case CompiledIndexExprNode::Kind::Add:
case CompiledIndexExprNode::Kind::Sub:
case CompiledIndexExprNode::Kind::Mul:
case CompiledIndexExprNode::Kind::DivUI:
case CompiledIndexExprNode::Kind::DivSI:
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::DivSI:
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
return mlir::failure();
return *lhs / *rhs;
case CompiledIndexExprNode::Kind::RemUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::RemSI:
if (*rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default: llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
if (failed(condition))
return mlir::failure();
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
}
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
if (!denseAttr || !globalType)
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(expr.node->operands.size());
for (const CompiledIndexExpr& operand : expr.node->operands) {
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
}
llvm_unreachable("unknown compiled index kind");
}
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
expr.globalOp = globalOp;
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
expr.operands.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto compiledIndex = compileIndexValueImpl(index);
if (failed(compiledIndex))
return mlir::failure();
expr.operands.push_back(*compiledIndex);
}
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = integerAttr.getInt();
return makeCompiledIndexExpr(std::move(expr));
}
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
auto lhs = compileIndexValueImpl(lhsValue);
auto rhs = compileIndexValueImpl(rhsValue);
if (failed(lhs) || failed(rhs))
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {*lhs, *rhs};
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
};
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return compileIndexValueImpl(indexCastOp.getIn());
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
if (failed(expr))
return mlir::failure();
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
exprNode->predicate = cmpOp.getPredicate();
return CompiledIndexExpr(exprNode);
}
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
auto lhs = compileIndexValueImpl(maxOp.getLhs());
auto rhs = compileIndexValueImpl(maxOp.getRhs());
if (failed(lhs) || failed(rhs))
return mlir::failure();
CompiledIndexExprNode cmpExpr;
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
cmpExpr.operands = {*lhs, *rhs};
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
return makeCompiledIndexExpr(std::move(selectExpr));
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = compileIndexValueImpl(selectOp.getCondition());
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
if (failed(condition) || failed(trueValue) || failed(falseValue))
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Select;
expr.operands = {*condition, *trueValue, *falseValue};
return makeCompiledIndexExpr(std::move(expr));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return compileConstantGlobalLoad(loadOp);
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) { llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge); value = resolveAlias(value, knowledge);
@@ -110,6 +404,24 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs)); return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
} }
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return mlir::failure();
return *lhs / *rhs;
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) { if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge); auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge); auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
@@ -118,6 +430,34 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs)); return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
} }
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
}
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure(); return mlir::failure();
} }
@@ -209,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides)) if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure(); return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); auto sourceStrides = getStaticMemRefStrides(sourceType);
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; if (failed(sourceStrides))
return mlir::failure();
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge); value = resolveAlias(subviewOp.getSource(), knowledge);
continue; continue;
} }
@@ -235,17 +577,206 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
} }
} }
} // namespace llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
int64_t constantByteOffset = 0;
CompiledIndexExpr byteOffsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); } while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return CompiledAddressExpr {value, byteOffsetExpr};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = tiedOperand->get();
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
bool hasOnlyStaticOffsets = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
hasOnlyStaticOffsets = false;
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
if (!attr)
return mlir::failure();
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
if (!attr)
return mlir::failure();
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
if (!isContiguousSubviewWithDynamicOffsets(
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
return mlir::failure();
}
if (hasOnlyStaticOffsets) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
constantByteOffset +=
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
* getElementTypeSizeInBytes(subviewType.getElementType());
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
else {
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
if (failed(compiledOffset))
return mlir::failure();
CompiledIndexExpr scaleExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
scaleExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Mul;
expr.operands = {*compiledOffset, scaleExpr};
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {offsetExpr, operandExpr};
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, offsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
constantByteOffset = 0;
}
value = subviewOp.getSource();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
if (constantByteOffset != 0) {
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
byteOffsetExpr = constantExpr;
else {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, byteOffsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
}
return CompiledAddressExpr {value, byteOffsetExpr};
}
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) { llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge); return resolveIndexValueImpl(value, &knowledge);
} }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) { llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value, llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) { const StaticValueKnowledge& knowledge) {
@@ -256,4 +787,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
return resolveLoopCarriedAliasImpl(value, &knowledge); return resolveLoopCarriedAliasImpl(value, &knowledge);
} }
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
return compileContiguousAddressExprImpl(value);
}
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset};
}
} // namespace onnx_mlir } // namespace onnx_mlir
+55 -4
View File
@@ -1,10 +1,14 @@
#pragma once #pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include <memory>
#include <optional>
namespace onnx_mlir { namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known /// Describes a value as a base addressable object plus a statically known
@@ -23,21 +27,68 @@ struct StaticValueKnowledge {
StaticValueKnowledge() {} StaticValueKnowledge() {}
}; };
struct CompiledIndexExprNode;
struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
: node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
};
struct CompiledIndexExprNode {
enum class Kind {
Constant,
Symbol,
Add,
Sub,
Mul,
DivUI,
DivSI,
RemUI,
RemSI,
MinUI,
CmpI,
Select,
ConstantGlobalLoad
};
Kind kind = Kind::Constant;
int64_t constant = 0;
mlir::Value symbol;
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t, 4> globalStrides;
llvm::SmallVector<CompiledIndexExpr, 4> operands;
};
struct CompiledAddressExpr {
mlir::Value base;
CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const;
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp); mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be /// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews. /// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value, llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge); const StaticValueKnowledge& knowledge = {});
/// Statically evaluates index-like SSA values, including simple integer /// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`. /// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value); llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge); llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a /// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result. /// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge); mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
} // namespace onnx_mlir } // namespace onnx_mlir
+182
View File
@@ -0,0 +1,182 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs < 0)
--quotient;
return quotient;
}
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs > 0)
++quotient;
return quotient;
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1)
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt());
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
}
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
if (multiplier == 0)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
if (multiplier == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
}
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.mod divisor");
if (divisor == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor);
}
Value affineFloorDivConst(
RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
if (divisor == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
return constant.getValue();
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
unsigned position = dim.getPosition();
if (position >= dims.size())
return failure();
return dims[position];
}
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
unsigned position = symbol.getPosition();
if (position >= symbols.size())
return failure();
return symbols[position];
}
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binary)
return failure();
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
if (failed(lhs) || failed(rhs))
return failure();
switch (binary.getKind()) {
case AffineExprKind::Add: return *lhs + *rhs;
case AffineExprKind::Mul: return *lhs * *rhs;
case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
case AffineExprKind::Mod: {
FailureOr<int64_t> div = floorDivSigned(*lhs, *rhs);
if (failed(div))
return failure();
return *lhs - *div * *rhs;
}
default: return failure();
}
}
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
return failure();
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
return evaluateAffineExpr(map.getResult(0), dims, symbols);
}
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
SmallVector<int64_t, 4> operands;
operands.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = resolver(operand);
if (failed(folded))
return failure();
operands.push_back(*folded);
}
return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
}
bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; }
bool isDimAndConstantAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
} // namespace onnx_mlir
+55
View File
@@ -0,0 +1,55 @@
#pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
namespace onnx_mlir {
using IndexValueResolver = llvm::function_ref<llvm::FailureOr<int64_t>(mlir::Value)>;
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineMap map,
mlir::ValueRange operands,
mlir::Operation* constantAnchor);
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange dims,
mlir::Operation* constantAnchor);
mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t multiplier,
mlir::Operation* constantAnchor);
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
llvm::FailureOr<int64_t>
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
llvm::FailureOr<int64_t> evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef<int64_t> operands);
llvm::FailureOr<int64_t> evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver);
bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map);
bool isDimAndConstantAffineExpr(mlir::AffineExpr expr);
} // namespace onnx_mlir
+32
View File
@@ -0,0 +1,32 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
llvm::SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
}
} // namespace onnx_mlir
+18
View File
@@ -0,0 +1,18 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
} // namespace onnx_mlir
+157
View File
@@ -0,0 +1,157 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
if (!constantOp.getType().isIndex())
return std::nullopt;
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (!intAttr || !intAttr.getType().isIndex())
return std::nullopt;
return intAttr.getInt();
}
Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
return &funcOp.getBody().front();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
return moduleOp.getBody();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(hostBlock);
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
}
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
if (funcOp.getBody().empty())
return;
Block& entryBlock = funcOp.getBody().front();
DenseMap<int64_t, Value> canonicalByValue;
SmallVector<arith::ConstantOp> constants;
funcOp.walk([&](arith::ConstantOp constantOp) {
if (!getIndexConstantValue(constantOp))
return;
constants.push_back(constantOp);
});
for (arith::ConstantOp constantOp : constants) {
auto value = getIndexConstantValue(constantOp);
if (!value || constantOp->getBlock() != &entryBlock)
continue;
canonicalByValue.try_emplace(*value, constantOp.getResult());
}
for (arith::ConstantOp constantOp : constants) {
auto value = getIndexConstantValue(constantOp);
if (!value)
continue;
Value canonical = canonicalByValue.lookup(*value);
if (!canonical) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&entryBlock);
Builder builder(funcOp.getContext());
canonical =
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
canonicalByValue[*value] = canonical;
}
if (constantOp.getResult() == canonical)
continue;
constantOp.getResult().replaceAllUsesWith(canonical);
}
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
auto value = getIndexConstantValue(constantOp);
if (!value)
continue;
if (constantOp.getResult() == canonicalByValue.lookup(*value))
continue;
if (constantOp.use_empty())
rewriter.eraseOp(constantOp);
}
}
std::optional<int64_t> matchConstantIndexValue(Value value) {
if (!value || !value.getType().isIndex())
return std::nullopt;
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
return std::nullopt;
}
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
if (auto operand = dyn_cast<Value>(value))
return matchConstantIndexValue(operand);
return std::nullopt;
}
} // namespace onnx_mlir
+33
View File
@@ -0,0 +1,33 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
#include <optional>
namespace onnx_mlir {
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
mlir::Value
getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Value
getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter);
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
} // namespace onnx_mlir
+82 -12
View File
@@ -1,24 +1,37 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) { bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp, if (mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp, mlir::arith::AddIOp,
mlir::arith::SubIOp, mlir::arith::SubIOp,
mlir::arith::MulIOp, mlir::arith::MulIOp,
mlir::arith::DivUIOp, mlir::arith::DivUIOp,
mlir::arith::RemUIOp, mlir::arith::DivSIOp,
mlir::arith::IndexCastOp, mlir::arith::MinUIOp,
mlir::memref::AllocOp, mlir::arith::RemUIOp,
mlir::memref::SubViewOp, mlir::arith::RemSIOp,
mlir::memref::CastOp, mlir::arith::IndexCastOp,
mlir::memref::CollapseShapeOp, mlir::arith::CmpIOp,
mlir::memref::ExpandShapeOp>(op); mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op))
return true;
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
return selectOp.getType().isIntOrIndex();
return false;
} }
mlir::LogicalResult mlir::LogicalResult
@@ -29,6 +42,9 @@ walkPimCoreBlock(mlir::Block& block,
for (mlir::Operation& op : block) { for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op)) if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue; continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) { if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front(); mlir::Block& loopBody = forOp.getRegion().front();
@@ -64,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block,
return mlir::success(!hasFailure); return mlir::success(!hasFailure);
} }
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
hasFailure = true;
continue;
}
if (*step <= 0) {
forOp.emitOpError("requires positive scf.for step for PIM verification");
hasFailure = true;
continue;
}
llvm::SmallVector<int64_t, 2> samples;
if (*lowerBound < *upperBound) {
samples.push_back(*lowerBound);
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
if (last != *lowerBound)
samples.push_back(last);
}
for (int64_t inductionValue : samples) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
hasFailure = true;
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir } // namespace onnx_mlir
+8
View File
@@ -21,4 +21,12 @@ walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge, const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback); llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
/// Walks a `pim.core`-like body structurally for verification without
/// enumerating full loop trip counts. Loop bounds must still be statically
/// evaluable so address resolution remains well-defined.
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir } // namespace onnx_mlir
+96
View File
@@ -0,0 +1,96 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
#include "ConstantUtils.hpp"
#include "LoopUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static std::optional<int64_t> getStaticTripCount(Value lowerBound, Value upperBound, Value step) {
auto lower = matchConstantIndexValue(lowerBound);
auto upper = matchConstantIndexValue(upperBound);
auto stepValue = matchConstantIndexValue(step);
if (!lower || !upper || !stepValue)
return std::nullopt;
if (*stepValue <= 0)
return std::nullopt;
if (*upper <= *lower)
return int64_t {0};
return llvm::divideCeil(*upper - *lower, *stepValue);
}
} // namespace
static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef<Value> yieldedValues) {
if (yieldedValues.size() == initArgs.size())
return success();
emitError(loc) << "normalized loop body yielded " << yieldedValues.size() << " values for " << initArgs.size()
<< " iter args";
return failure();
}
FailureOr<NormalizedLoopResult> buildNormalizedScfFor(OpBuilder& builder,
Location loc,
Value lowerBound,
Value upperBound,
Value step,
ValueRange initArgs,
NormalizedLoopBodyBuilder bodyBuilder) {
NormalizedLoopResult result;
if (auto stepValue = matchConstantIndexValue(step); stepValue && *stepValue <= 0) {
emitError(loc) << "normalized scf.for requires a positive step, got " << *stepValue;
return failure();
}
if (auto tripCount = getStaticTripCount(lowerBound, upperBound, step)) {
if (*tripCount == 0) {
llvm::append_range(result.results, initArgs);
return result;
}
if (*tripCount == 1) {
result.inductionVar = lowerBound;
if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, result.results)))
return failure();
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results)))
return failure();
return result;
}
}
result.loop = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs);
result.inductionVar = result.loop.getInductionVar();
{
OpBuilder::InsertionGuard guard(builder);
Block* body = result.loop.getBody();
if (!body->empty())
if (auto yieldOp = dyn_cast<scf::YieldOp>(body->back()))
yieldOp->erase();
builder.setInsertionPointToEnd(body);
ValueRange iterArgs = result.loop.getRegionIterArgs();
if (failed(bodyBuilder(builder, loc, result.inductionVar, iterArgs, result.results))) {
result.loop.erase();
return failure();
}
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results))) {
result.loop.erase();
return failure();
}
scf::YieldOp::create(builder, loc, result.results);
}
builder.setInsertionPointAfter(result.loop);
result.results.assign(result.loop.getResults().begin(), result.loop.getResults().end());
return result;
}
} // namespace onnx_mlir
+30
View File
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
struct NormalizedLoopResult {
mlir::Value inductionVar;
llvm::SmallVector<mlir::Value, 4> results;
mlir::scf::ForOp loop;
bool wasInlined() const { return !loop; }
};
using NormalizedLoopBodyBuilder = llvm::function_ref<mlir::LogicalResult(
mlir::OpBuilder&, mlir::Location, mlir::Value, mlir::ValueRange, llvm::SmallVectorImpl<mlir::Value>&)>;
mlir::FailureOr<NormalizedLoopResult> buildNormalizedScfFor(mlir::OpBuilder& builder,
mlir::Location loc,
mlir::Value lowerBound,
mlir::Value upperBound,
mlir::Value step,
mlir::ValueRange initArgs,
NormalizedLoopBodyBuilder bodyBuilder);
} // namespace onnx_mlir
+77
View File
@@ -1,4 +1,5 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
return numElements; return numElements;
} }
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
@@ -86,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
return true; return true;
} }
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides) {
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|| sourceShape.size() != staticStrides.size()) {
return false;
}
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
return false;
auto reversedTriples =
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
auto [_sourceDim, offset, _size] = triple;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
return true;
});
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
if (size > sourceDim - staticOffset)
return false;
}
++firstNonZeroOrDynamicOffset;
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
if (std::get<2>(*it) != 1)
return false;
}
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
auto [sourceDim, size] = pair;
return size != sourceDim;
});
if (firstDifferentSize != reversedSizes.end()) {
++firstDifferentSize;
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
if (std::get<1>(*it) != 1)
return false;
}
return true;
}
} // namespace onnx_mlir } // namespace onnx_mlir
+17
View File
@@ -1,8 +1,14 @@
#pragma once #pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir { namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape); llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
@@ -14,9 +20,20 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
int64_t getNumElements(llvm::ArrayRef<int64_t> shape); int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape, bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets, llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes, llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides); llvm::ArrayRef<int64_t> strides);
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides);
} // namespace onnx_mlir } // namespace onnx_mlir
+23 -2
View File
@@ -1,7 +1,6 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
@@ -32,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
} }
} }
Value stripMemRefAddressingOps(Value value) {
while (true) {
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
value = subviewOp.getSource();
continue;
}
Value strippedValue = stripMemRefViewOps(value);
if (strippedValue == value)
return value;
value = strippedValue;
}
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) { bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); }) return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); }) && llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
@@ -82,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
return staticOffsets; return staticOffsets;
} }
bool isMemRefBaseAddressableValue(Value value) {
value = stripMemRefAddressingOps(value);
if (isa<BlockArgument>(value))
return true;
Operation* defOp = value.getDefiningOp();
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
}
} // namespace onnx_mlir } // namespace onnx_mlir
+4
View File
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value); mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::Value stripMemRefAddressingOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview); bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value); llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic. /// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info); llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
bool isMemRefBaseAddressableValue(mlir::Value value);
} // namespace onnx_mlir } // namespace onnx_mlir
+233 -25
View File
@@ -1,8 +1,14 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp" #include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -19,29 +25,67 @@ void markWeightAlways(mlir::Operation* op) {
namespace { namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy> CompiledIndexExpr makeConstantExpr(int64_t constant) {
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) { CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constant;
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
}
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {std::move(lhs), std::move(rhs)};
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
}
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
}
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
return mlir::failure();
return strides;
}
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
bool found = false; bool found = false;
parentOp.walk([&](mlir::Operation* op) { parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op)) if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex; found |= vmmOp.getWeight() == *weightArg;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
}); });
return found; return found;
} }
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy> template <typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) { void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights(); auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited; llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) { auto walkWeight = [&](mlir::Value weight) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second) for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
callback(parentOp->getOpOperand(weightIndex)); auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg || *weightArg != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
}; };
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); }); parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
} }
} // namespace } // namespace
@@ -54,7 +98,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
if (!computeOp || operandIndex >= computeOp.getWeights().size()) if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false; return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex); return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
} }
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) { bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
@@ -76,8 +120,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self); return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user)) if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self); return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user)) if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self); return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
return false; return false;
}); });
@@ -90,19 +134,183 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op"); assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) { root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) { coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights(); if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
unsigned weightIndex = vmmOp.getWeightIndex(); callback(coreOp->getOpOperand(*weightIndex));
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
}); });
}); });
root->walk([&](pim::PimCoreBatchOp coreBatchOp) { root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights(); coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
for (auto weight : weights) if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
for (mlir::OpOperand& use : weight.getUses()) callback(coreBatchOp->getOpOperand(*weightIndex));
if (use.getOwner() == coreBatchOp.getOperation()) });
callback(use);
}); });
} }
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
weight = stripMemRefAddressingOps(weight);
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == weight)
return weightIndex;
return std::nullopt;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
return weightIndex;
return std::nullopt;
}
return std::nullopt;
}
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
llvm::SmallVector<mlir::Operation*> viewOps;
mlir::Value current = weight;
while (true) {
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
current = directAlias;
continue;
}
if (auto defOp = current.getDefiningOp()) {
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return mlir::failure();
ResolvedWeightView view;
view.globalOp = globalOp;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
CompiledIndexExpr offsetValue = makeConstantExpr(0);
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!intAttr)
return mlir::failure();
offsetValue = makeConstantExpr(intAttr.getInt());
}
else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
auto compiledOffset = compileIndexExpr(value);
if (failed(compiledOffset))
return mlir::failure();
offsetValue = *compiledOffset;
}
else {
return mlir::failure();
}
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
}
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
return mlir::failure();
}
auto resolvedOffset = offsetExpr.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
view.offset = *resolvedOffset;
return view;
}
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
viewOps.push_back(defOp);
current = subview.getSource();
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = collapse.getSrc();
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = expand.getSrc();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
viewOps.push_back(defOp);
current = castOp.getSource();
continue;
}
return mlir::failure();
}
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
current = loopAlias;
continue;
}
auto weightIndex = resolveWeightIndex(weightOwner, current);
if (!weightIndex)
return mlir::failure();
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
current = coreOp.getWeights()[*weightIndex];
continue;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
current = coreBatchOp.getWeights()[*weightIndex];
continue;
}
return mlir::failure();
}
}
} // namespace onnx_mlir } // namespace onnx_mlir
+36 -1
View File
@@ -1,15 +1,34 @@
#pragma once #pragma once
#include "mlir/IR/Operation.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include <optional>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways"; inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir { namespace onnx_mlir {
struct ResolvedWeightView {
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t> shape;
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
bool operator==(const ResolvedWeightView& other) const {
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
}
};
bool hasWeightAlways(mlir::Operation* op); bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable /// Tags an op as producing a value that should stay materialized as a reusable
@@ -26,4 +45,20 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// passes can identify globals that must remain weight-backed. /// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback); void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
template <typename CoreLikeOpTy>
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
llvm::SmallVector<unsigned, 8> indices;
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
indices.push_back(*weightIndex);
});
llvm::sort(indices);
return indices;
}
} // namespace onnx_mlir } // namespace onnx_mlir
+1
View File
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp" #include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp" #include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp" #include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -0,0 +1,222 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) {
llvm::errs() << "PIM " << fieldName << " " << message << "\n";
}
template <typename To, typename From>
static FailureOr<To> checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCastAtLocation requires integral types");
using ToLimits = std::numeric_limits<To>;
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
return failure();
}
}
else if constexpr (std::is_signed_v<From>) {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::make_unsigned_t<To>;
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
return failure();
}
}
else {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
return failure();
}
}
return static_cast<To>(value);
}
template <typename UInt>
FailureOr<UInt> checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>,
"checkedMulAtLocation requires unsigned integral types");
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
emitCheckedArithmeticError(loc, fieldName, "multiplication overflow");
return failure();
}
return lhs * rhs;
}
} // namespace
InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) {
assert(anchor && "expected arithmetic diagnostics to have an anchor op");
return anchor->emitOpError() << fieldName << " " << message;
}
InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) {
return emitError(loc) << "PIM " << fieldName << " " << message;
}
FailureOr<int32_t> checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<int32_t>(value, anchor, fieldName);
}
FailureOr<int32_t> checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<int32_t>(value, anchor, fieldName);
}
FailureOr<uint8_t> checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<uint8_t>(value, anchor, fieldName);
}
FailureOr<size_t> checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<size_t>(value, anchor, fieldName);
}
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) {
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
auto checkedValue = checkedI32(value, anchor, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) {
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
auto checkedValue = checkedI32(value, anchor, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) {
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) {
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) {
assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor");
if (!type.hasStaticShape()) {
emitCheckedArithmeticError(anchor, fieldName, "requires static shaped type");
return failure();
}
if (!hasByteSizedElementType(type.getElementType())) {
emitCheckedArithmeticError(anchor, fieldName, "requires byte-sized element type");
return failure();
}
uint64_t elements = 1;
for (int64_t dim : type.getShape()) {
if (dim < 0) {
emitCheckedArithmeticError(anchor, fieldName, "requires nonnegative dimensions");
return failure();
}
auto nextElements = checkedMul(elements, static_cast<uint64_t>(dim), anchor, fieldName);
if (failed(nextElements))
return failure();
elements = *nextElements;
}
return checkedMul(
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), anchor, fieldName);
}
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) {
if (!type.hasStaticShape()) {
emitCheckedArithmeticError(loc, fieldName, "requires static shaped type");
return failure();
}
if (!hasByteSizedElementType(type.getElementType())) {
emitCheckedArithmeticError(loc, fieldName, "requires byte-sized element type");
return failure();
}
uint64_t elements = 1;
for (int64_t dim : type.getShape()) {
if (dim < 0) {
emitCheckedArithmeticError(loc, fieldName, "requires nonnegative dimensions");
return failure();
}
auto nextElements = checkedMulAtLocation(elements, static_cast<uint64_t>(dim), loc, fieldName);
if (failed(nextElements))
return failure();
elements = *nextElements;
}
return checkedMulAtLocation(
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), loc, fieldName);
}
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) {
if (value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max()) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<int32_t>(value);
}
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) {
if (value > static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<int32_t>(value);
}
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) {
if (value > static_cast<uint64_t>(std::numeric_limits<uint8_t>::max())) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<uint8_t>(value);
}
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) {
if (value < 0) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<size_t>(value);
}
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
if (rhs > std::numeric_limits<size_t>::max() - lhs) {
emitCrashMessage(fieldName, "addition overflow");
llvm_unreachable("PIM checked arithmetic failure");
}
return lhs + rhs;
}
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
if (lhs != 0 && rhs > std::numeric_limits<size_t>::max() / lhs) {
emitCrashMessage(fieldName, "multiplication overflow");
llvm_unreachable("PIM checked arithmetic failure");
}
return lhs * rhs;
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,107 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
template <typename To, typename From>
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
using ToLimits = std::numeric_limits<To>;
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
return mlir::failure();
}
}
else if constexpr (std::is_signed_v<From>) {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::make_unsigned_t<To>;
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
return mlir::failure();
}
}
else {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
return mlir::failure();
}
}
return static_cast<To>(value);
}
template <typename UInt>
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
if (rhs > std::numeric_limits<UInt>::max() - lhs) {
emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
return mlir::failure();
}
return lhs + rhs;
}
template <typename UInt>
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
return mlir::failure();
}
return lhs * rhs;
}
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
} // namespace onnx_mlir::pim
+1 -1
View File
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out); std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file); llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags; mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs(); flags.elideLargeElementsAttrs().enableDebugInfo(true, false);
moduleOp.print(os, flags); moduleOp.print(os, flags);
os.flush(); os.flush();
file.close(); file.close();
+26
View File
@@ -7,10 +7,36 @@
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <system_error> #include <system_error>
namespace onnx_mlir::pim { namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
numFailures++;
if (numFailures <= maxReportedFailures)
emit(op);
}
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
void noteFailures(int64_t count) { numFailures += count; }
bool hasFailure() const { return numFailures != 0; }
private:
int64_t maxReportedFailures;
int64_t numFailures = 0;
};
/// Emits a consistent diagnostic for target paths that require static shapes. /// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription); mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
+5 -4
View File
@@ -1,21 +1,22 @@
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" #include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
namespace onnx_mlir { namespace onnx_mlir {
std::fstream openReportFile(const std::string& name) { std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
std::string outputDir = getOutputDir(); std::string outputDir = getOutputDir();
if (outputDir.empty()) if (outputDir.empty())
return {}; return {};
std::string reportsDir = outputDir + "/reports"; std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir); createDirectory(reportsDir);
return std::fstream(reportsDir + "/" + name + ".txt", std::ios::out); return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out);
} }
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
std::string formatReportMemory(uint64_t bytes) { std::string formatReportMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"}; const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0; int i = 0;
+2 -2
View File
@@ -1,10 +1,9 @@
#pragma once #pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <fstream> #include <fstream>
#include <limits> #include <limits>
#include <string> #include <string>
@@ -12,6 +11,7 @@
namespace onnx_mlir { namespace onnx_mlir {
std::fstream openReportFile(const std::string& name); std::fstream openReportFile(const std::string& name);
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
std::string formatReportMemory(uint64_t bytes); std::string formatReportMemory(uint64_t bytes);
struct ReportField { struct ReportField {
+4 -2
View File
@@ -16,8 +16,8 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp PimCompilerUtils.cpp
PimArtifactWriter.cpp PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp PimCodeGen.cpp
PimMemoryLiveness.cpp
PimWeightEmitter.cpp PimWeightEmitter.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -29,7 +29,9 @@ add_pim_library(OMPimCompilerUtils
OMPimCompilerOptions OMPimCompilerOptions
OMPimCommon OMPimCommon
OMPimBufferization OMPimBufferization
OMPimStaticMemoryCoalescing OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimVerification
OMPimPasses OMPimPasses
OMONNXToSpatial OMONNXToSpatial
OMSpatialToPim OMSpatialToPim
+1 -1
View File
@@ -48,7 +48,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
if (!denseAttr) if (!denseAttr)
return; return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult()); MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
ArrayRef<char> rawData = denseAttr.getRawData(); ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address; char* dst = memoryBuffer.data() + memEntry.address;
-136
View File
@@ -1,136 +0,0 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
} // namespace onnx_mlir
-13
View File
@@ -1,13 +0,0 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
+40 -52
View File
@@ -6,8 +6,8 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <array> #include <array>
#include <cassert>
#include <limits> #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
namespace onnx_mlir::pim_binary { namespace onnx_mlir::pim_binary {
@@ -70,9 +70,7 @@ inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
os.write(bytes.data(), bytes.size()); os.write(bytes.data(), bytes.size());
} }
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
writeUint32LE(os, static_cast<uint32_t>(value));
}
inline void writeHeader(llvm::raw_ostream& os) { inline void writeHeader(llvm::raw_ostream& os) {
os.write(kMagic, sizeof(kMagic)); os.write(kMagic, sizeof(kMagic));
@@ -97,15 +95,10 @@ inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecor
writeInt32LE(os, record.generic3); writeInt32LE(os, record.generic3);
} }
inline int32_t toI32(int64_t value) { inline int32_t toI32(int64_t value) { return onnx_mlir::pim::checkedI32OrCrash(value, "binary field"); }
assert(value >= std::numeric_limits<int32_t>::min() && value <= std::numeric_limits<int32_t>::max()
&& "PIM binary field out of int32 range");
return static_cast<int32_t>(value);
}
inline uint8_t toU8(int64_t value) { inline uint8_t toU8(int64_t value) {
assert(value >= 0 && value <= std::numeric_limits<uint8_t>::max() && "PIM binary field out of uint8 range"); return onnx_mlir::pim::checkedU8OrCrash(static_cast<uint64_t>(value), "binary field");
return static_cast<uint8_t>(value);
} }
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) { inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
@@ -186,39 +179,39 @@ inline Opcode opcodeFromString(llvm::StringRef opName) {
inline llvm::StringRef opcodeToString(Opcode opcode) { inline llvm::StringRef opcodeToString(Opcode opcode) {
switch (opcode) { switch (opcode) {
case Opcode::nop: return "nop"; case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi"; case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld"; case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd"; case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub"; case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul"; case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi"; case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli"; case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw"; case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul"; case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd"; case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub"; case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul"; case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul"; case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax"; case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll"; case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra"; case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg"; case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu"; case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh"; case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm"; case Opcode::vsigm: return "vsigm";
case Opcode::vsoftmax: return "vsoftmax"; case Opcode::vsoftmax: return "vsoftmax";
case Opcode::vmv: return "vmv"; case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu"; case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl"; case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld"; case Opcode::ld: return "ld";
case Opcode::st: return "st"; case Opcode::st: return "st";
case Opcode::lldi: return "lldi"; case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv"; case Opcode::lmv: return "lmv";
case Opcode::send: return "send"; case Opcode::send: return "send";
case Opcode::recv: return "recv"; case Opcode::recv: return "recv";
case Opcode::wait: return "wait"; case Opcode::wait: return "wait";
case Opcode::sync: return "sync"; case Opcode::sync: return "sync";
} }
llvm_unreachable("Unsupported PIM binary opcode"); llvm_unreachable("Unsupported PIM binary opcode");
} }
@@ -235,9 +228,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
case Opcode::sldi: case Opcode::sldi:
case Opcode::saddi: case Opcode::saddi:
case Opcode::smuli: case Opcode::smuli:
case Opcode::lldi: case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
record.r2OrImm = getOptionalInt(instruction, "imm");
break;
case Opcode::mvmul: case Opcode::mvmul:
record.r2OrImm = getOptionalInt(instruction, "mbiw"); record.r2OrImm = getOptionalInt(instruction, "mbiw");
record.generic1 = getOptionalInt(instruction, "relu"); record.generic1 = getOptionalInt(instruction, "relu");
@@ -252,9 +243,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
record.r2OrImm = getOptionalInt(instruction, "core"); record.r2OrImm = getOptionalInt(instruction, "core");
record.generic3 = getOptionalInt(instruction, "size"); record.generic3 = getOptionalInt(instruction, "size");
break; break;
default: default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
record.r2OrImm = getOptionalInt(instruction, "rs2");
break;
} }
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) { if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
@@ -371,8 +360,7 @@ inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
break; break;
case Opcode::wait: case Opcode::wait:
case Opcode::sync: case Opcode::sync:
case Opcode::nop: case Opcode::nop: break;
break;
} }
return instruction; return instruction;
File diff suppressed because it is too large Load Diff
+93 -13
View File
@@ -4,13 +4,18 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h" #include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/JSON.h" #include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h" #include "llvm/Support/raw_os_ostream.h"
#include <fstream> #include <fstream>
#include <limits>
#include <optional> #include <optional>
#include <string>
#include "onnx-mlir/Compiler/OMCompilerTypes.h" #include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp" #include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp" #include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
@@ -23,6 +28,23 @@ struct MemEntry {
size_t size; size_t size;
}; };
struct PhysicalSlotInfo {
size_t id = 0;
size_t address = 0;
size_t size = 0;
};
struct MemoryPlanArtifacts {
std::string textReport;
};
struct MemoryValueKey {
mlir::Value value;
std::optional<unsigned> lane;
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
};
struct MemoryReportRow { struct MemoryReportRow {
uint64_t numAlloca = 0; uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0; uint64_t sizeAlloca = 0;
@@ -35,6 +57,19 @@ struct MemoryReportRow {
} }
}; };
enum class MemoryReportKind {
None,
Alloca,
Global,
Input
};
struct PendingMemEntry {
MemEntry memEntry;
MemoryValueKey key;
MemoryReportKind reportKind = MemoryReportKind::None;
};
struct MemoryReportEntry { struct MemoryReportEntry {
enum class Kind { enum class Kind {
Core, Core,
@@ -50,33 +85,39 @@ struct MemoryReportEntry {
}; };
class PimMemory { class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries; llvm::SmallVector<PendingMemEntry, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap; llvm::SmallVector<PhysicalSlotInfo, 32> localPhysicalSlots;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap; llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
MemoryReportRow reportRow;
MemoryPlanArtifacts livenessArtifacts;
size_t minAlignment = 4; size_t minAlignment = 4;
size_t firstAvailableAddress = 0; size_t firstAvailableAddress = 0;
size_t nextPhysicalSlotId = 0;
MemEntry* gatherMemEntry(mlir::Value value); MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory(); void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry); void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind);
PhysicalSlotInfo allocatePhysicalSlot(size_t slotSize, const MemoryValueKey& key);
public: public:
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap) PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {} : globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp); void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op); void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
MemoryReportRow getReportRow() const; MemoryReportRow getReportRow() const;
const MemoryPlanArtifacts& getLivenessArtifacts() const { return livenessArtifacts; }
void remove(mlir::Value val); void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; } size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const; MemEntry getMemEntry(const MemoryValueKey& key) const;
}; };
class PimAcceleratorMemory { class PimAcceleratorMemory {
public: public:
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap; llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
PimMemory hostMem; PimMemory hostMem;
private: private:
@@ -84,14 +125,24 @@ private:
std::fstream fileReport; std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow; std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries; llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
uint64_t totalWeightBytes = 0;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
public: public:
PimAcceleratorMemory() PimAcceleratorMemory()
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {} : hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
: memEntriesMap(initialMemEntries),
hostMem(memEntriesMap),
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
PimMemory& getOrCreateDeviceMem(size_t id); PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const; size_t getValueAddress(mlir::Value value,
const StaticValueKnowledge& knowledge = {},
std::optional<unsigned> lane = std::nullopt) const;
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost(); void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row); void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId, void recordBatchReport(uint64_t batchId,
@@ -99,19 +150,29 @@ public:
const MemoryReportRow& perCoreRow, const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount, uint64_t totalAllocaCount,
uint64_t totalAllocaBytes); uint64_t totalAllocaBytes);
void setTotalWeightBytes(uint64_t bytes) { totalWeightBytes = bytes; }
void flushReport(); void flushReport();
void clean(mlir::Operation* op); void clean(mlir::Operation* op);
}; };
struct CoreEmissionJob {
mlir::Operation* coreLikeOp = nullptr;
size_t originalCoreId = 0;
size_t emittedCoreId = 0;
llvm::SmallVector<unsigned, 4> lanes;
std::optional<uint64_t> batchReportId;
};
class PimCodeGen { class PimCodeGen {
PimAcceleratorMemory& memory; PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreBinaryStream; llvm::raw_fd_ostream& coreBinaryStream;
llvm::raw_fd_ostream* coreJsonStream; llvm::raw_fd_ostream* coreJsonStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds; const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0; mutable uint32_t emittedInstructionCount = 0;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const { size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge); return memory.getValueAddress(value, knowledge, batchLane);
} }
size_t remapCoreId(size_t coreId) const; size_t remapCoreId(size_t coreId) const;
@@ -141,15 +202,17 @@ public:
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {} : memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; } uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getIndexValue(value, knowledge);
}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const; void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const; void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const; void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const; void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const; void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const; void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy> template <typename MVMTy>
@@ -172,3 +235,20 @@ public:
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName); OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir } // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
}
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
};
} // namespace llvm
+36 -8
View File
@@ -1,3 +1,5 @@
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions" #define DEBUG_TYPE "PimCompilerOptions"
@@ -13,12 +15,35 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen), llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType>
pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
"pim-memory-report",
llvm::cl::desc("Emit a human-readable PIM memory planning report"),
llvm::cl::values(clEnumValN(PimMemoryReportNone, "none", "Do not emit any PIM memory planning report")),
llvm::cl::values(
clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders")),
llvm::cl::values(clEnumValN(PimMemoryReportFull, "full", "Emit the full detailed PIM memory planning report")),
llvm::cl::init(PimMemoryReportNone),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen", pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"), llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
llvm::cl::init(false), llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimDisableMemoryCoalescing("pim-disable-memory-coalescing",
llvm::cl::desc("Skip the PIM memory coalescing pass (developer diagnostic option)"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl", llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::desc("Use experimental implementation for convolution"), llvm::cl::desc("Use experimental implementation for convolution"),
llvm::cl::init(false), llvm::cl::init(false),
@@ -30,24 +55,27 @@ llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::cat(OnnxMlirOptions)); llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2)); crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t> llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256)); crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count", llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."), llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1)); llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(4000));
llvm::cl::opt<bool> llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error", ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"), llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false)); llvm::cl::init(false));
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
void verifyExplicitPimCoreCount() {
if (!hasExplicitPimCoreCount())
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
if (coresCount.getValue() <= 0)
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
}
} // namespace onnx_mlir } // namespace onnx_mlir
+16 -1
View File
@@ -20,17 +20,32 @@ typedef enum {
EmitPimCodegen = 3 EmitPimCodegen = 3
} PimEmissionTargetType; } PimEmissionTargetType;
typedef enum {
MergeSchedulerPeft = 0,
} PimMergeSchedulerType;
typedef enum {
PimMemoryReportNone = 0,
PimMemoryReportSummary = 1,
PimMemoryReportFull = 2,
} PimMemoryReportLevel;
extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget; extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
extern llvm::cl::opt<bool> pimOnlyCodegen; extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
extern llvm::cl::opt<bool> useExperimentalConvImpl; extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> pimEmitJson; extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<size_t> crossbarSize; extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore; extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount; extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
// This option, by default set to false, will ignore an error when resolving a // This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the // specific tiles of the operands of a concat. This specific case is when the
+3 -6
View File
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
PassManager& pm, PassManager& pm,
EmissionTargetType& emissionTarget, EmissionTargetType& emissionTarget,
std::string outputNameNoExt) { std::string outputNameNoExt) {
verifyExplicitPimCoreCount();
if (pimOnlyCodegen) { if (pimOnlyCodegen) {
// Skip all the lowering passes and directly generate code for PIM. // Skip all the lowering passes and directly generate code for PIM.
@@ -29,31 +30,27 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) { if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass()); pm.addPass(createONNXToSpatialPass());
pm.addPass(createMergeComputeNodesPass()); pm.addPass(createMergeComputeNodesPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial")); pm.addPass(createMessagePass("Onnx lowered to Spatial"));
} }
if (pimEmissionTarget >= EmitPim) { if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPimPass()); pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim")); pm.addPass(createMessagePass("Spatial lowered to Pim"));
} }
if (pimEmissionTarget >= EmitPimBufferized) { if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass()); pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized")); pm.addPass(createMessagePass("Pim bufferized"));
} }
if (pimEmissionTarget >= EmitPimCodegen) { if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimHostConstantFoldingPass()); pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded")); pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass()); if (!pimDisableMemoryCoalescing)
pm.addPass(createPimMemoryCoalescingPass());
pm.addPass(createPimVerificationPass()); pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified")); pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimCodePass()); pm.addPass(createEmitPimCodePass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim code emitted")); pm.addPass(createMessagePass("Pim code emitted"));
} }
} }
+733
View File
@@ -0,0 +1,733 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
namespace {
static std::optional<unsigned> getLaneForMemoryValue(mlir::Value value, std::optional<unsigned> lane) {
if (!lane)
return std::nullopt;
auto allocOp = value.getDefiningOp<memref::AllocOp>();
if (!allocOp || !allocOp->getParentOfType<pim::PimCoreBatchOp>())
return std::nullopt;
return lane;
}
static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigned> lane = std::nullopt) {
return {value, getLaneForMemoryValue(value, lane)};
}
struct MemoryTouchInterval {
uint64_t start = 0;
uint64_t end = 0;
Operation* startOp = nullptr;
Operation* endOp = nullptr;
Operation* firstTouchOp = nullptr;
Operation* lastTouchOp = nullptr;
uint64_t firstTouchPosition = 0;
uint64_t lastTouchPosition = 0;
bool hasRuntimeUse = false;
bool startUsedAllocFallback = false;
bool endUsedFallback = false;
bool escapesLoop = false;
std::string fallbackReason;
llvm::SmallVector<std::string, 8> aliasesFollowed;
};
struct OperationOrdering {
llvm::DenseMap<Operation*, uint64_t> position;
llvm::DenseMap<Operation*, uint64_t> subtreeEnd;
uint64_t nextPosition = 0;
};
static std::string printValueToString(mlir::Value value) {
std::string text;
llvm::raw_string_ostream os(text);
value.print(os);
os.flush();
return text;
}
static std::string printOperationToString(Operation* op) {
if (!op)
return "<none>";
std::string text;
llvm::raw_string_ostream os(text);
op->print(os);
os.flush();
return text;
}
static std::string printLocationToString(Location loc) {
std::string text;
llvm::raw_string_ostream os(text);
loc.print(os);
os.flush();
return text;
}
static std::string collapseWhitespace(StringRef text) {
std::string out;
out.reserve(text.size());
bool lastWasSpace = false;
for (char c : text) {
bool isSpace = c == ' ' || c == '\n' || c == '\t' || c == '\r';
if (isSpace) {
if (!lastWasSpace && !out.empty())
out.push_back(' ');
lastWasSpace = true;
continue;
}
out.push_back(c);
lastWasSpace = false;
}
return out;
}
static std::string abbreviate(StringRef text, size_t maxLen) {
if (text.size() <= maxLen)
return text.str();
return (text.take_front(maxLen - 3) + "...").str();
}
static std::string summarizeValue(mlir::Value value, size_t maxLen = 72) {
return abbreviate(collapseWhitespace(printValueToString(value)), maxLen);
}
static std::string summarizeOperation(Operation* op, size_t maxLen = 96) {
if (!op)
return "<none>";
std::string prefix = op->getName().getStringRef().str();
std::string full = collapseWhitespace(printOperationToString(op));
if (full == prefix)
return prefix;
return abbreviate(prefix + " :: " + full, maxLen);
}
static std::string summarizeLocation(Location loc, size_t maxLen = 88) {
return abbreviate(collapseWhitespace(printLocationToString(loc)), maxLen);
}
static void assignOperationOrdering(Operation* op, OperationOrdering& ordering) {
uint64_t position = ordering.nextPosition++;
ordering.position[op] = position;
uint64_t end = position;
for (Region& region : op->getRegions())
for (Block& block : region)
for (Operation& nestedOp : block) {
assignOperationOrdering(&nestedOp, ordering);
end = std::max(end, ordering.subtreeEnd.lookup(&nestedOp));
}
ordering.subtreeEnd[op] = end;
}
static OperationOrdering buildOperationOrdering(Operation* coreLikeOp) {
OperationOrdering ordering;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return ordering;
for (Operation& op : coreLikeOp->getRegion(0).front())
assignOperationOrdering(&op, ordering);
return ordering;
}
static bool isSupportedAliasOp(Operation* op) {
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
}
static bool isRuntimeMemoryTouchOp(Operation* op) {
return isa<pim::PimMemCopyHostToDevOp,
pim::PimMemCopyDevToHostOp,
pim::PimMemCopyOp,
pim::PimReceiveOp,
pim::PimSendOp,
pim::PimConcatOp,
pim::PimVMMOp,
pim::PimTransposeOp,
pim::PimVVAddOp,
pim::PimVVSubOp,
pim::PimVVMulOp,
pim::PimVVMaxOp,
pim::PimVVDMulOp,
pim::PimVAvgOp,
pim::PimVReluOp,
pim::PimVTanhOp,
pim::PimVSigmOp,
pim::PimVSoftmaxOp>(op);
}
static bool isIgnoredLivenessUser(Operation* op) {
return isSupportedAliasOp(op) || isa<scf::ForOp, scf::YieldOp, memref::DeallocOp>(op) || isCoreStaticAddressOp(op);
}
static bool isWithin(mlir::Value value, Region* region) {
if (!region)
return false;
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent() == region;
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion() == region || region->isAncestor(definingOp->getParentRegion());
return false;
}
static bool isNestedAllocation(Operation* coreLikeOp, memref::AllocOp allocOp) {
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return false;
return allocOp->getBlock() != &coreLikeOp->getRegion(0).front();
}
static void addFallbackReason(std::string& reason, StringRef newReason) {
if (newReason.empty())
return;
if (!reason.empty())
reason += "; ";
reason += newReason.str();
}
static void appendAliasDescription(llvm::SmallVectorImpl<std::string>& aliases, mlir::Value value) {
std::string text = printValueToString(value);
if (!llvm::is_contained(aliases, text))
aliases.push_back(std::move(text));
}
struct OrderedTouchRange {
uint64_t start = 0;
uint64_t end = 0;
Operation* startOp = nullptr;
Operation* endOp = nullptr;
bool escapedLoop = false;
};
static OrderedTouchRange
getEffectiveTouchRange(mlir::Value definingValue, Operation* user, const OperationOrdering& ordering) {
OrderedTouchRange range {ordering.position.lookup(user), ordering.position.lookup(user), user, user, false};
for (Operation* current = user; current; current = current->getParentOp()) {
auto forOp = dyn_cast<scf::ForOp>(current);
if (!forOp || isWithin(definingValue, &forOp.getRegion()))
continue;
range.start = std::min(range.start, ordering.position.lookup(forOp));
range.end = std::max(range.end, ordering.subtreeEnd.lookup(forOp));
range.startOp = forOp;
range.endOp = forOp;
range.escapedLoop = true;
}
return range;
}
static MemoryTouchInterval
computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering& ordering, uint64_t fallbackEnd) {
MemoryTouchInterval interval;
interval.start = ordering.position.lookup(allocOp);
interval.end = interval.start;
interval.startOp = allocOp;
interval.endOp = allocOp;
SmallPtrSet<mlir::Value, 16> visitedValues;
SmallPtrSet<Operation*, 32> visitedUsers;
SmallVector<mlir::Value> pendingValues;
pendingValues.push_back(allocOp.getResult());
auto parentLoop = allocOp->getParentOfType<scf::ForOp>();
while (!pendingValues.empty()) {
mlir::Value value = pendingValues.pop_back_val();
if (!visitedValues.insert(value).second)
continue;
for (Operation* user : value.getUsers()) {
if (!visitedUsers.insert(user).second)
continue;
if (isSupportedAliasOp(user)) {
for (mlir::Value result : user->getResults()) {
pendingValues.push_back(result);
appendAliasDescription(interval.aliasesFollowed, result);
}
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) {
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
if (!tiedOperand || tiedOperand->get() != value)
continue;
pendingValues.push_back(result);
appendAliasDescription(interval.aliasesFollowed, result);
}
}
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
if (initArg != value)
continue;
pendingValues.push_back(forOp.getRegionIterArgs()[index]);
pendingValues.push_back(forOp.getResult(index));
appendAliasDescription(interval.aliasesFollowed, forOp.getRegionIterArgs()[index]);
appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index));
if (parentLoop && forOp != parentLoop)
interval.escapesLoop = true;
}
}
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp) {
addFallbackReason(interval.fallbackReason, "yield without scf.for parent");
}
else {
for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) {
if (operand != value)
continue;
pendingValues.push_back(forOp.getResult(index));
appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index));
if (parentLoop && forOp == parentLoop)
interval.escapesLoop = true;
}
}
}
if (isRuntimeMemoryTouchOp(user)) {
uint64_t touchPosition = ordering.position.lookup(user);
if (!interval.hasRuntimeUse || touchPosition < interval.firstTouchPosition) {
interval.firstTouchPosition = touchPosition;
interval.firstTouchOp = user;
}
if (!interval.hasRuntimeUse || touchPosition > interval.lastTouchPosition) {
interval.lastTouchPosition = touchPosition;
interval.lastTouchOp = user;
}
OrderedTouchRange range = getEffectiveTouchRange(allocOp.getResult(), user, ordering);
interval.escapesLoop |= range.escapedLoop;
if (!interval.hasRuntimeUse) {
interval.start = range.start;
interval.end = range.end;
interval.startOp = range.startOp;
interval.endOp = range.endOp;
interval.hasRuntimeUse = true;
}
else {
if (range.start < interval.start) {
interval.start = range.start;
interval.startOp = range.startOp;
}
if (range.end > interval.end) {
interval.end = range.end;
interval.endOp = range.endOp;
}
}
continue;
}
if (isIgnoredLivenessUser(user))
continue;
addFallbackReason(interval.fallbackReason, "unhandled user op");
interval.endUsedFallback = true;
}
}
if (!interval.hasRuntimeUse) {
interval.startUsedAllocFallback = true;
interval.endUsedFallback = true;
interval.start = ordering.position.lookup(allocOp);
interval.end = fallbackEnd;
interval.startOp = allocOp;
interval.endOp = allocOp->getParentOp();
interval.firstTouchPosition = interval.start;
interval.lastTouchPosition = interval.end;
addFallbackReason(interval.fallbackReason, "no runtime memory touch");
return interval;
}
if (interval.endUsedFallback) {
interval.end = std::max(interval.end, fallbackEnd);
interval.endOp = allocOp->getParentOp();
}
return interval;
}
static FailureOr<size_t> getAllocSizeBytes(memref::AllocOp allocOp) {
auto type = dyn_cast<ShapedType>(allocOp.getType());
if (!type)
return failure();
auto checkedBytes = pim::getCheckedShapedTypeSizeInBytes(type, allocOp, "memory allocation byte size");
if (failed(checkedBytes))
return failure();
return pim::checkedSize(*checkedBytes, allocOp, "memory allocation byte size");
}
static bool intervalsOverlap(const LocalAllocInterval& lhs, const LocalAllocInterval& rhs) {
return !(lhs.end < rhs.start || rhs.end < lhs.start);
}
static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot& slot, ArrayRef<LocalAllocInterval> intervals) {
uint64_t slotLogicalBytes = 0;
for (size_t intervalIndex : slot.intervalIndices)
slotLogicalBytes += intervals[intervalIndex].size;
return slotLogicalBytes;
}
} // namespace
SmallVector<LocalAllocInterval, 0> onnx_mlir::buildLocalAllocIntervals(Operation* coreLikeOp,
std::optional<unsigned> lane) {
SmallVector<LocalAllocInterval, 0> intervals;
OperationOrdering ordering = buildOperationOrdering(coreLikeOp);
if (ordering.position.empty())
return intervals;
uint64_t fallbackEnd = ordering.nextPosition == 0 ? 0 : ordering.nextPosition - 1;
size_t nextIntervalId = 0;
coreLikeOp->walk([&](memref::AllocOp allocOp) {
auto checkedSize = getAllocSizeBytes(allocOp);
if (failed(checkedSize)) {
llvm::errs() << "Failed to compute local allocation size for value: ";
allocOp.getResult().print(llvm::errs());
llvm::errs() << "\n";
llvm_unreachable("Failed to compute local allocation size");
}
MemoryTouchInterval touchInterval = computeMemoryTouchInterval(allocOp, ordering, fallbackEnd);
LocalAllocInterval interval;
interval.id = nextIntervalId++;
interval.alloc = allocOp;
interval.key = getMemoryValueKey(allocOp.getResult(), lane);
interval.start = touchInterval.start;
interval.end = touchInterval.end;
interval.size = *checkedSize;
interval.startOp = touchInterval.startOp;
interval.endOp = touchInterval.endOp;
interval.firstTouchOp = touchInterval.firstTouchOp;
interval.lastTouchOp = touchInterval.lastTouchOp;
interval.firstTouchPosition = touchInterval.firstTouchPosition;
interval.lastTouchPosition = touchInterval.lastTouchPosition;
interval.startUsedAllocFallback = touchInterval.startUsedAllocFallback;
interval.endUsedFallback = touchInterval.endUsedFallback;
interval.hasRuntimeUse = touchInterval.hasRuntimeUse;
interval.insideNestedRegion = isNestedAllocation(coreLikeOp, allocOp);
interval.escapesLoop = touchInterval.escapesLoop;
interval.fallbackReason = std::move(touchInterval.fallbackReason);
interval.aliasesFollowed = std::move(touchInterval.aliasesFollowed);
intervals.push_back(std::move(interval));
});
return intervals;
}
SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef<LocalAllocInterval> intervals) {
SmallVector<PlannedPhysicalSlot, 0> slots;
SmallVector<size_t> intervalOrder(intervals.size());
std::iota(intervalOrder.begin(), intervalOrder.end(), 0);
llvm::stable_sort(intervalOrder, [&](size_t lhsIndex, size_t rhsIndex) {
const LocalAllocInterval& lhs = intervals[lhsIndex];
const LocalAllocInterval& rhs = intervals[rhsIndex];
if (lhs.size != rhs.size)
return lhs.size > rhs.size;
if (lhs.start != rhs.start)
return lhs.start < rhs.start;
if (lhs.end != rhs.end)
return lhs.end < rhs.end;
return lhs.id < rhs.id;
});
for (size_t intervalIndex : intervalOrder) {
LocalAllocInterval& interval = intervals[intervalIndex];
PlannedPhysicalSlot* bestSlot = nullptr;
auto bestKey = std::tuple<size_t, size_t, size_t, size_t>(std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max());
for (size_t slotIndex = 0; slotIndex < slots.size(); ++slotIndex) {
PlannedPhysicalSlot& slot = slots[slotIndex];
bool compatible = true;
for (size_t otherIndex : slot.intervalIndices) {
if (intervalsOverlap(interval, intervals[otherIndex])) {
compatible = false;
break;
}
}
if (!compatible)
continue;
size_t resultingSize = std::max(slot.requiredSize, interval.size);
size_t growth = resultingSize - slot.requiredSize;
auto candidateKey =
std::tuple<size_t, size_t, size_t, size_t>(growth, resultingSize, slot.intervalIndices.size(), slot.id);
if (candidateKey < bestKey) {
bestKey = candidateKey;
bestSlot = &slot;
}
}
if (!bestSlot) {
slots.push_back({slots.size(), interval.size, interval.size, 0, {intervalIndex}});
interval.slotPlanIndex = slots.size() - 1;
interval.physicalSlotId = slots.back().id;
interval.physicalSlotSize = slots.back().requiredSize;
continue;
}
bestSlot->requiredSize = std::max(bestSlot->requiredSize, interval.size);
bestSlot->size = bestSlot->requiredSize;
bestSlot->intervalIndices.push_back(intervalIndex);
interval.slotPlanIndex = static_cast<size_t>(bestSlot - slots.data());
interval.physicalSlotId = bestSlot->id;
interval.physicalSlotSize = bestSlot->requiredSize;
}
return slots;
}
MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation* coreLikeOp,
std::optional<unsigned> lane,
ArrayRef<LocalAllocInterval> intervals,
ArrayRef<PlannedPhysicalSlot> slots,
size_t addressLimit,
PimMemoryReportLevel reportLevel) {
MemoryPlanArtifacts artifacts;
uint64_t totalLogicalBytes = 0;
uint64_t totalPhysicalBytes = 0;
uint64_t fallbackIntervals = 0;
uint64_t noRuntimeTouchIntervals = 0;
uint64_t reusedAllocations = 0;
uint64_t nestedIntervals = 0;
uint64_t loopEscapingIntervals = 0;
size_t largestLogicalAllocation = 0;
size_t largestPhysicalSlot = 0;
size_t maximumAssignedAddress = 0;
for (const LocalAllocInterval& interval : intervals) {
totalLogicalBytes += interval.size;
largestLogicalAllocation = std::max(largestLogicalAllocation, interval.size);
maximumAssignedAddress = std::max(maximumAssignedAddress, interval.assignedAddress + interval.physicalSlotSize);
if (interval.startUsedAllocFallback || interval.endUsedFallback)
++fallbackIntervals;
if (!interval.hasRuntimeUse)
++noRuntimeTouchIntervals;
if (interval.insideNestedRegion)
++nestedIntervals;
if (interval.escapesLoop)
++loopEscapingIntervals;
}
for (const PlannedPhysicalSlot& slot : slots) {
totalPhysicalBytes += slot.size;
largestPhysicalSlot = std::max(largestPhysicalSlot, slot.size);
if (slot.intervalIndices.size() > 1)
reusedAllocations += slot.intervalIndices.size() - 1;
}
uint64_t savedBytes = totalLogicalBytes >= totalPhysicalBytes ? totalLogicalBytes - totalPhysicalBytes : 0;
double savedPercent =
totalLogicalBytes == 0 ? 0.0 : 100.0 * static_cast<double>(savedBytes) / static_cast<double>(totalLogicalBytes);
raw_string_ostream os(artifacts.textReport);
os << "=== PIM Memory Liveness Report ===\n";
os << "Op: " << coreLikeOp->getName() << "\n";
if (lane)
os << "Lane: " << *lane << "\n";
os << "Summary:\n";
os << " logical allocation bytes: " << formatReportMemory(totalLogicalBytes) << " (" << totalLogicalBytes << ")\n";
os << " physical allocation bytes: " << formatReportMemory(totalPhysicalBytes) << " (" << totalPhysicalBytes
<< ")\n";
os << " saved bytes: " << formatReportMemory(savedBytes) << " (" << savedBytes << ")\n";
os << " saved percent: " << format("%.2f%%", savedPercent) << "\n";
os << " intervals: " << intervals.size() << "\n";
os << " physical slots: " << slots.size() << "\n";
os << " reused allocations: " << reusedAllocations << "\n";
os << " fallback intervals: " << fallbackIntervals << "\n";
os << " intervals with no runtime memory touch: " << noRuntimeTouchIntervals << "\n";
os << " nested allocations: " << nestedIntervals << "\n";
os << " loop-escaping allocations: " << loopEscapingIntervals << "\n";
os << " largest logical allocation: " << largestLogicalAllocation << "\n";
os << " largest physical slot: " << largestPhysicalSlot << "\n";
os << " address limit: " << addressLimit << "\n";
os << " peak physical memory: " << formatReportMemory(maximumAssignedAddress) << " (" << maximumAssignedAddress
<< ")\n";
os << " maximum assigned address: " << maximumAssignedAddress << "\n";
os << "\nHow To Read:\n";
os << " `summary` only shows the strongest reuse cases and the worst offenders.\n";
os << " Use `--pim-memory-report=full` when you need the complete slot-by-slot and interval-by-interval dump.\n";
os << " Large single-use slots, fallback intervals, and nested single-use allocations are the best places\n";
os << " to inspect if allocations should be moved, sunk, or made easier to coalesce earlier in the pipeline.\n";
SmallVector<const PlannedPhysicalSlot*> reusedSlots;
SmallVector<const PlannedPhysicalSlot*> singleUseSlots;
for (const PlannedPhysicalSlot& slot : slots)
if (slot.intervalIndices.size() > 1)
reusedSlots.push_back(&slot);
else
singleUseSlots.push_back(&slot);
llvm::stable_sort(reusedSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
uint64_t lhsLogicalBytes = getSlotLogicalBytes(*lhs, intervals);
uint64_t rhsLogicalBytes = getSlotLogicalBytes(*rhs, intervals);
if (lhs->intervalIndices.size() != rhs->intervalIndices.size())
return lhs->intervalIndices.size() > rhs->intervalIndices.size();
if (lhsLogicalBytes != rhsLogicalBytes)
return lhsLogicalBytes > rhsLogicalBytes;
if (lhs->size != rhs->size)
return lhs->size > rhs->size;
return lhs->id < rhs->id;
});
llvm::stable_sort(singleUseSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
if (lhs->size != rhs->size)
return lhs->size > rhs->size;
return lhs->id < rhs->id;
});
constexpr size_t kSummaryReuseLimit = 6;
constexpr size_t kSummaryOffenderLimit = 10;
os << "\nBest Reuse:\n";
if (reusedSlots.empty()) {
os << " no slots were shared by multiple intervals\n";
}
else {
for (const PlannedPhysicalSlot* slot : ArrayRef(reusedSlots).take_front(kSummaryReuseLimit)) {
uint64_t slotLogicalBytes = getSlotLogicalBytes(*slot, intervals);
os << " slot #" << slot->id << " addr=" << slot->address << " size=" << formatReportMemory(slot->size)
<< " intervals=" << slot->intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
<< "\n";
for (size_t intervalIndex : slot->intervalIndices) {
const LocalAllocInterval& interval = intervals[intervalIndex];
os << " #" << interval.id << " [" << interval.start << "," << interval.end << "]"
<< " logical=" << formatReportMemory(interval.size)
<< " first=" << summarizeOperation(interval.firstTouchOp, 40)
<< " last=" << summarizeOperation(interval.lastTouchOp, 40) << "\n";
}
}
}
os << "\nTop Offenders:\n";
bool printedAttention = false;
for (const PlannedPhysicalSlot* slot : ArrayRef(singleUseSlots).take_front(kSummaryOffenderLimit)) {
const LocalAllocInterval& interval = intervals[slot->intervalIndices.front()];
printedAttention = true;
os << " slot #" << slot->id << " is single-use"
<< " size=" << formatReportMemory(slot->size) << " interval=#" << interval.id
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " first=" << summarizeOperation(interval.firstTouchOp, 40)
<< " last=" << summarizeOperation(interval.lastTouchOp, 40)
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") << "\n";
}
size_t fallbackPrinted = 0;
for (const LocalAllocInterval& interval : intervals) {
if (!(interval.startUsedAllocFallback || interval.endUsedFallback) || fallbackPrinted >= kSummaryOffenderLimit)
continue;
printedAttention = true;
++fallbackPrinted;
os << " fallback interval #" << interval.id << " size=" << formatReportMemory(interval.size)
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " reason: " << (interval.fallbackReason.empty() ? "<none>" : interval.fallbackReason) << "\n";
}
size_t nestedPrinted = 0;
for (const LocalAllocInterval& interval : intervals) {
if (nestedPrinted >= kSummaryOffenderLimit)
break;
if (!(interval.insideNestedRegion && slots[interval.slotPlanIndex].intervalIndices.size() == 1))
continue;
printedAttention = true;
++nestedPrinted;
os << " nested single-use interval #" << interval.id << " slot #" << interval.physicalSlotId
<< " size=" << formatReportMemory(interval.size) << " value=" << summarizeValue(interval.key.value, 56)
<< "\n";
os << " hint: move or sink this alloc inside the nested region if the IR allows it.\n";
}
if (!printedAttention)
os << " no obvious blockers detected in this core\n";
if (reportLevel == PimMemoryReportFull) {
os << "\nSlot Reuse:\n";
for (const PlannedPhysicalSlot& slot : slots) {
uint64_t slotLogicalBytes = getSlotLogicalBytes(slot, intervals);
os << " slot #" << slot.id << " addr=" << slot.address << " size=" << formatReportMemory(slot.size) << " ("
<< slot.size << ")"
<< " intervals=" << slot.intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
<< "\n";
for (size_t intervalIndex : slot.intervalIndices) {
const LocalAllocInterval& interval = intervals[intervalIndex];
mlir::Value allocValue = interval.key.value;
os << " [" << interval.start << "," << interval.end << "]"
<< " #" << interval.id << " logical=" << formatReportMemory(interval.size)
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
<< " first=" << summarizeOperation(interval.firstTouchOp, 48)
<< " last=" << summarizeOperation(interval.lastTouchOp, 48) << "\n";
os << " value=" << summarizeValue(allocValue) << "\n";
}
}
}
if (reportLevel == PimMemoryReportFull) {
os << "\nInterval Details:\n";
for (const LocalAllocInterval& interval : intervals) {
const PlannedPhysicalSlot& slot = slots[interval.slotPlanIndex];
mlir::Value allocValue = interval.key.value;
Operation* definingOp = allocValue.getDefiningOp();
os << " #" << interval.id << " slot=" << slot.id << " live=[" << interval.start << "," << interval.end << "]"
<< " logical=" << formatReportMemory(interval.size)
<< " slot_size=" << formatReportMemory(interval.physicalSlotSize) << " addr=" << interval.assignedAddress
<< "\n";
os << " value=" << summarizeValue(allocValue, 88) << "\n";
os << " type=" << allocValue.getType() << "\n";
os << " loc="
<< summarizeLocation(definingOp ? definingOp->getLoc() : UnknownLoc::get(coreLikeOp->getContext())) << "\n";
os << " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
<< " start_fallback=" << (interval.startUsedAllocFallback ? "yes" : "no")
<< " end_fallback=" << (interval.endUsedFallback ? "yes" : "no") << "\n";
os << " first_use=" << summarizeOperation(interval.firstTouchOp) << " @" << interval.firstTouchPosition
<< "\n";
os << " last_use=" << summarizeOperation(interval.lastTouchOp) << " @" << interval.lastTouchPosition << "\n";
os << " slot_peers=";
bool first = true;
for (size_t otherIndex : slot.intervalIndices) {
if (intervals[otherIndex].id == interval.id)
continue;
if (!first)
os << ", ";
os << "#" << intervals[otherIndex].id;
first = false;
}
if (first)
os << "<none>";
os << "\n";
if (!interval.fallbackReason.empty())
os << " fallback_reason=" << interval.fallbackReason << "\n";
if (!interval.aliasesFollowed.empty()) {
os << " aliases_followed=" << interval.aliasesFollowed.size() << "\n";
for (const std::string& alias : interval.aliasesFollowed)
os << " - " << abbreviate(collapseWhitespace(alias), 108) << "\n";
}
}
}
os.flush();
return artifacts;
}
+63
View File
@@ -0,0 +1,63 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <limits>
#include <optional>
#include <string>
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
struct LocalAllocInterval {
size_t id = 0;
mlir::memref::AllocOp alloc;
MemoryValueKey key;
uint64_t start = 0;
uint64_t end = 0;
size_t size = 0;
mlir::Operation* startOp = nullptr;
mlir::Operation* endOp = nullptr;
mlir::Operation* firstTouchOp = nullptr;
mlir::Operation* lastTouchOp = nullptr;
uint64_t firstTouchPosition = 0;
uint64_t lastTouchPosition = 0;
bool startUsedAllocFallback = false;
bool endUsedFallback = false;
bool hasRuntimeUse = false;
bool insideNestedRegion = false;
bool escapesLoop = false;
std::string fallbackReason;
llvm::SmallVector<std::string, 8> aliasesFollowed;
size_t slotPlanIndex = std::numeric_limits<size_t>::max();
size_t physicalSlotId = std::numeric_limits<size_t>::max();
size_t assignedAddress = 0;
size_t physicalSlotSize = 0;
};
struct PlannedPhysicalSlot {
size_t id = std::numeric_limits<size_t>::max();
size_t requiredSize = 0;
size_t size = 0;
size_t address = 0;
llvm::SmallVector<size_t, 8> intervalIndices;
};
llvm::SmallVector<LocalAllocInterval, 0> buildLocalAllocIntervals(mlir::Operation* coreLikeOp,
std::optional<unsigned> lane);
llvm::SmallVector<PlannedPhysicalSlot, 0> planPhysicalSlots(llvm::MutableArrayRef<LocalAllocInterval> intervals);
MemoryPlanArtifacts buildMemoryPlanArtifacts(mlir::Operation* coreLikeOp,
std::optional<unsigned> lane,
llvm::ArrayRef<LocalAllocInterval> intervals,
llvm::ArrayRef<PlannedPhysicalSlot> slots,
size_t addressLimit,
PimMemoryReportLevel reportLevel);
} // namespace onnx_mlir
+57 -173
View File
@@ -1,209 +1,93 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h" #include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cassert> #include <cassert>
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp" #include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm; using namespace llvm;
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {} // namespace
struct DenseWeightView { WeightEmissionResult createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<memref::SubViewOp> subviews;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
subviews.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (memref::SubViewOp subview : llvm::reverse(subviews)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
}
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
auto coreWeightsDirPath = outputDirPath + "/weights"; auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath); auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory"); assert(!error && "Error creating weights directory");
size_t indexFileName = 0; size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue(); int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName; WeightEmissionResult result;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName; llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp); auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
it != materializedWeights.end())
return it->second;
for (Operation* op : coreLikeOps) { auto globalOp = weightView.globalOp;
auto processCore = [&](pim::PimCoreOp coreOp) { auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
size_t coreId = static_cast<size_t>(coreOp.getCoreId()); assert(denseAttr && "Weight global must have dense initial value");
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto weightView = resolveDenseWeightView(moduleOp, weight); ArrayRef<int64_t> shape = weightView.shape;
if (failed(weightView)) { assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index)); int64_t numRows = shape[0];
assert(succeeded(weightView) && "Weight is not from a dense memref.global view"); int64_t numCols = shape[1];
} assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
if (mapCoreWeightToFileName[coreId].contains(weight)) size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
continue;
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>(); std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {}; auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) { std::error_code errorCode;
auto& fileName = mapGlobalOpToFileName[globalOp]; raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
mapCoreWeightToFileName[coreId].insert({weight, fileName}); if (errorCode) {
continue; errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
} assert(errorCode);
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
continue;
} }
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op); uint64_t zero = 0;
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane) for (int64_t row = 0; row < xbarSize; row++) {
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore))) for (int64_t col = 0; col < xbarSize; col++) {
return mapCoreWeightToFileName; if (row < numRows && col < numCols) {
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
materializedWeights.push_back({weightView, newFileName});
uint64_t weightBytes = pim::checkedMulOrCrash(
pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
elementByteWidth,
"weight byte size");
result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
return newFileName;
};
for (const WeightFileRequest& request : requests) {
auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
coreFiles.reserve(request.weights.size());
for (const ResolvedWeightView& weight : request.weights)
coreFiles.push_back(materializeWeight(weight));
} }
return mapCoreWeightToFileName;
return result;
} }
} // namespace onnx_mlir } // namespace onnx_mlir
+16 -3
View File
@@ -1,16 +1,29 @@
#pragma once #pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <string> #include <string>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
namespace onnx_mlir { namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> struct WeightFileRequest {
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath); size_t coreId = 0;
llvm::SmallVector<ResolvedWeightView, 8> weights;
};
struct WeightEmissionResult {
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
uint64_t totalWeightBytes = 0;
};
WeightEmissionResult createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen) add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp Patterns.cpp
HostFoldability.cpp CompileTime.cpp
HostLegality.cpp ONNXToSpatialVerifier.cpp
PrePatterns.cpp Patterns/Pre.cpp
PostPatterns.cpp Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp Patterns/Math/Gemm.cpp
@@ -21,9 +22,13 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Gather.cpp Patterns/Tensor/Gather.cpp
Patterns/Tensor/Resize.cpp Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Slice.cpp
Patterns/Tensor/Split.cpp Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp ONNXToSpatialPass.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
Common/ShapeTilingUtils.cpp Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp Common/WeightMaterialization.cpp
@@ -33,6 +38,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen ONNXToSpatialIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect MLIRSCFDialect
MLIRTosaDialect MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
@@ -0,0 +1,23 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
return attr ? getI64Attr(*attr, index) : defaultValue;
}
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
llvm::SmallVector<int64_t> values;
values.reserve(attr.size());
for (Attribute value : attr)
values.push_back(cast<IntegerAttr>(value).getInt());
return values;
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
namespace onnx_mlir {
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
} // namespace onnx_mlir
@@ -1,6 +1,8 @@
#pragma once #pragma once
#include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp" #include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp" #include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -1,5 +1,6 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h" #include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h" #include "mlir/IR/ValueRange.h"
@@ -7,9 +8,12 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <limits>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
@@ -18,13 +22,17 @@ namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); } inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
}
template <typename Fn, size_t... Is> template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) { decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...); return std::forward<Fn>(fn)(block->getArgument(Is)...);
} }
template <typename Fn, size_t... Is> template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) { decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...); return std::forward<Fn>(fn)(values[Is]...);
} }
@@ -45,6 +53,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
template <typename Fn> template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>; using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
struct SpatComputeBatchBodyArgs {
mlir::Value lane;
mlir::ValueRange weights;
mlir::ValueRange inputs;
mlir::ValueRange outputs;
};
} // namespace detail } // namespace detail
template <typename RewriterT> template <typename RewriterT>
@@ -85,6 +100,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs) for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc); block->addArgument(input.getType(), loc);
@@ -93,14 +110,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>; using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {}); detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
else { else {
auto bodyResult = auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {}); detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
@@ -123,6 +143,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs); auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block(); auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs) for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc); block->addArgument(input.getType(), loc);
@@ -131,13 +153,13 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>; using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) { if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block)); std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
return computeOp; return computeOp;
} }
else { else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block)); auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
if (mlir::failed(bodyResult)) { if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp); rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp); rewriter.eraseOp(computeOp);
@@ -148,6 +170,98 @@ auto createSpatCompute(RewriterT& rewriter,
} }
} }
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
for (mlir::Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(weight.getLoc());
}
for (mlir::Value input : inputs) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(input.getLoc());
}
for (mlir::Type resultType : resultTypes) {
blockArgTypes.push_back(resultType);
blockArgLocs.push_back(loc);
}
auto* block =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(block);
detail::SpatComputeBatchBodyArgs args {
block->getArgument(0),
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
mlir::ArrayRef<mlir::OpFoldResult> offsets,
mlir::ArrayRef<mlir::OpFoldResult> sizes,
mlir::ArrayRef<mlir::OpFoldResult> strides) {
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
}
template <typename BodyFn>
mlir::Value materializeOrComputeUnary(mlir::Value input,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc,
BodyFn&& build) {
auto&& buildFn = build;
if (isCompileTimeComputable(input))
return buildFn(input);
auto computeOp = createSpatCompute<1>(
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
mlir::Value result = buildFn(computeInput);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter); mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,45 @@
#include <algorithm>
#include "IndexingUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
int64_t normalizedAxis = normalizeAxis(axis, rank);
if (normalizedAxis < 0 || normalizedAxis >= rank)
return failure();
return normalizedAxis;
}
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; ++axis)
normalizedAxes.push_back(axis);
}
else {
normalizedAxes.reserve(axesAttr->size());
for (Attribute attr : *axesAttr)
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
}
return normalizedAxes;
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes)
if (axis < 0 || axis >= rank)
return failure();
return normalizedAxes;
}
} // namespace onnx_mlir
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank);
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
} // namespace onnx_mlir
@@ -3,26 +3,93 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <functional>
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp" #include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
SmallVector<Value> sliceTensor( SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) { const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice); ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size()); assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1)); SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
SmallVector<OpFoldResult> sizes; SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
sizes.reserve(shape.size());
for (const auto size : shape)
sizes.push_back(rewriter.getIndexAttr(size));
sizes[axis] = rewriter.getIndexAttr(sliceSize); sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis]; long length = shape[axis];
@@ -44,7 +111,7 @@ SmallVector<Value> sliceTensor(
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType()); RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice; Value slice;
if (isHostFoldableValue(tensorToSlice)) { if (isCompileTimeComputable(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides); slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
} }
else { else {
@@ -80,38 +147,33 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
return slicesPerCore; return slicesPerCore;
} }
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix( Value extractAxisSlice(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) { PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile))); auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<int64_t> resultShape(sourceType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles; SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc); offsets[axis] = rewriter.getIndexAttr(offset);
size_t numHSlices = hSlices.size(); sizes[axis] = rewriter.getIndexAttr(size);
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) { return tensor::ExtractSliceOp::create(
Value hSlice = hSlices[hSliceId]; rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc); .getResult();
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
size_t coreId = vSliceId / crossbarCountInCore;
Value vSlice = vSlices[vSliceId];
tiles[hSliceId][coreId].push_back(vSlice);
}
}
return tiles;
} }
tensor::SplatOp Value insertStaticSlice(
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) { PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType()); auto sourceType = cast<RankedTensorType>(source.getType());
Type elementType = oldType.getElementType(); return tensor::InsertSliceOp::create(rewriter,
int64_t shape[2] = {1, length}; loc,
Type type = oldType.cloneWith(ArrayRef(shape), elementType); source,
dest,
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult(); offsets,
SmallVector<Value> index(oldType.getRank(), zero); getStaticSizes(rewriter, sourceType.getShape()),
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult(); getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h" #include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
@@ -11,46 +12,12 @@
#include <cassert> #include <cassert>
#include <cstddef> #include <cstddef>
#include <optional>
#include <type_traits> #include <type_traits>
#include <utility> #include <utility>
namespace onnx_mlir { namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t; using HSliceId = size_t;
using CoreId = size_t; using CoreId = size_t;
@@ -87,17 +54,6 @@ bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1; return shape.size() == 2 && shape[0] == 1;
} }
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) { inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape(); return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
} }
@@ -109,6 +65,25 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
&& lhsType.getShape() == rhsType.getShape(); && lhsType.getShape() == rhsType.getShape();
} }
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of /// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements. /// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice, llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
@@ -127,18 +102,13 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore( llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc); const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it mlir::Value extractAxisSlice(
/// can be assigned to crossbars grouped by core. mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast, mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
int64_t length, mlir::Location loc,
mlir::ConversionPatternRewriter& rewriter, mlir::Value source,
mlir::Location loc); mlir::Value dest,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
value = collapseShapeOp.getSrc(); value = collapseShapeOp.getSrc();
continue; continue;
} }
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) { if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
value = transposeOp.getData(); value = transposeOp.getInput();
continue; continue;
} }
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
return referencedValue.getResult(); return referencedValue.getResult();
} }
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp)) if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
return failure(); return failure();
IRMapping localMapper; IRMapping localMapper;
@@ -0,0 +1,299 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
SmallVector<int64_t> originalIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
return DenseElementsAttr::get(resultType, values);
}
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
tensor::ExtractSliceOp extractSliceOp) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<Attribute> resultValues;
resultValues.reserve(resultType.getNumElements());
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
int64_t remaining = linearIndex;
int64_t sourceLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
}
resultValues.push_back(sourceValues[sourceLinearIndex]);
}
return DenseElementsAttr::get(resultType, resultValues);
}
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
// Rebuild dense attributes through view-only host-foldable chains so later
// lowering stages can still recognize grouped/sliced constants.
if (auto denseAttr = getDirectDenseConstantAttr(value))
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm;
perm.reserve(transposeOp.getPermAttr().size());
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
perm.push_back(attr.getInt());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
}
return nullptr;
}
static std::optional<CompileTimeSource>
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
if (!op)
return std::nullopt;
if (!visited.insert(op).second)
return {
{op, chainLength}
};
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return {
{op, chainLength}
};
chainLength += 1;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp)
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (!isStaticTensorResult(op))
return std::nullopt;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp)
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
std::optional<CompileTimeSource> res = {};
for (auto operandValue : concatOp.getOperands()) {
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
if (!partialRes)
return std::nullopt;
if (!res) {
res = partialRes;
continue;
}
if (res->chainLength < partialRes->chainLength)
res = partialRes;
}
return res;
}
return std::nullopt;
}
} // namespace
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited);
}
bool isCompileTimeComputable(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(definingOp, visited).has_value();
}
bool isCompileTimeOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited).has_value();
}
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostConstantDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
@@ -0,0 +1,22 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
struct CompileTimeSource {
mlir::Operation* source;
size_t chainLength;
};
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
bool isCompileTimeComputable(mlir::Value value);
bool isCompileTimeOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -1,75 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
if (!isStaticTensorResult(op))
return false;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
return false;
}
} // namespace
bool isHostFoldableValue(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
}
bool isHostFoldableOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
}
} // namespace onnx_mlir
@@ -1,12 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
} // namespace onnx_mlir
@@ -1,29 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
bool hasFailure = false;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
op.emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,26 +1,22 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "Common/Common.hpp" #include "Common/Common.hpp"
#include "Common/PimCommon.hpp" #include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir; using namespace mlir;
@@ -46,7 +42,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext()); IRRewriter rewriter(funcOp.getContext());
IRMapping mapper; IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>()); SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty()) SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
return; return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator()); auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
@@ -91,13 +88,27 @@ void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation(); ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext(); MLIRContext* ctx = &getContext();
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
RewritePatternSet prePatterns(ctx); RewritePatternSet prePatterns(ctx);
populatePrePatterns(prePatterns, ctx); populatePrePatterns(prePatterns, ctx);
if (failed(applyPatternsGreedily(moduleOp, std::move(prePatterns)))) if (failed(applyPartialConversion(moduleOp, preTarget, std::move(prePatterns)))) {
moduleOp.emitWarning("failed to apply ONNX-to-Spatial pre-patterns; continuing"); moduleOp.emitError("failed to apply ONNX-to-Spatial pre-rewrites");
signalPassFailure();
return;
}
auto entryFunc = getPimEntryFunc(moduleOp); auto entryFunc = getPimEntryFunc(moduleOp);
if (failed(entryFunc)) { if (failed(entryFunc)) {
moduleOp.emitError("failed to locate the PIM entry function during ONNX-to-Spatial lowering");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -105,11 +116,15 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx); ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect, target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect, ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect, tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect, arith::ArithDialect,
scf::SCFDialect>(); scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>(); target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>(); target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXSubOp>();
target.addIllegalOp<ONNXDivOp>(); target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>(); target.addIllegalOp<ONNXMulOp>();
target.addIllegalOp<ONNXGemmOp>(); target.addIllegalOp<ONNXGemmOp>();
@@ -123,6 +138,7 @@ void ONNXToSpatialPass::runOnOperation() {
target.addIllegalOp<ONNXGatherOp>(); target.addIllegalOp<ONNXGatherOp>();
target.addIllegalOp<ONNXReshapeOp>(); target.addIllegalOp<ONNXReshapeOp>();
target.addIllegalOp<ONNXResizeOp>(); target.addIllegalOp<ONNXResizeOp>();
target.addIllegalOp<ONNXSliceOp>();
target.addIllegalOp<ONNXLRNOp>(); target.addIllegalOp<ONNXLRNOp>();
target.addIllegalOp<ONNXReduceMeanV13Op>(); target.addIllegalOp<ONNXReduceMeanV13Op>();
target.addIllegalOp<ONNXSplitOp>(); target.addIllegalOp<ONNXSplitOp>();
@@ -130,30 +146,19 @@ void ONNXToSpatialPass::runOnOperation() {
RewritePatternSet conversionPatterns(ctx); RewritePatternSet conversionPatterns(ctx);
populateConversionPatterns(conversionPatterns, ctx); populateConversionPatterns(conversionPatterns, ctx);
if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) { if (failed(applyPartialConversion(moduleOp, target, std::move(conversionPatterns)))) {
moduleOp.emitError("failed to convert required ONNX ops to Spatial ops");
signalPassFailure(); signalPassFailure();
return; return;
} }
RewritePatternSet earlyPostPatterns(ctx); ConversionTarget earlyPostTarget(*ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx); earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
if (failed(applyPatternsGreedily(*entryFunc, std::move(earlyPostPatterns)))) { ONNXDialect,
signalPassFailure(); linalg::LinalgDialect,
return; tensor::TensorDialect,
} affine::AffineDialect,
arith::ArithDialect,
if (coresCount != -1) { scf::SCFDialect>();
int computeOpsCount = 0;
for (Operation& op : entryFunc->getFunctionBody().front().getOperations())
if (isa<spatial::SpatCompute>(op))
computeOpsCount++;
if (computeOpsCount > coresCount) {
entryFunc->emitError() << "number of compute ops (" << computeOpsCount << ") exceeds the core count ("
<< coresCount << ")";
signalPassFailure();
return;
}
}
PassManager cleanupPM(ctx); PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass()); cleanupPM.addPass(createCanonicalizerPass());
@@ -162,14 +167,23 @@ void ONNXToSpatialPass::runOnOperation() {
annotateWeightsConstants(*entryFunc); annotateWeightsConstants(*entryFunc);
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
[](spatial::SpatCompute computeOp) { return !requiresPostRewrite(computeOp); });
postTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch computeOp) { return !requiresPostRewrite(computeOp); });
RewritePatternSet postPatterns(ctx); RewritePatternSet postPatterns(ctx);
populatePostPatterns(postPatterns, ctx); populatePostPatterns(postPatterns, ctx);
if (failed(applyPatternsGreedily(*entryFunc, std::move(postPatterns)))) { if (failed(applyPartialConversion(*entryFunc, postTarget, std::move(postPatterns)))) {
signalPassFailure(); moduleOp.emitError("failed to normalize weight-like Spatial compute operands before Spatial-to-PIM lowering");
return;
}
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -177,6 +191,11 @@ void ONNXToSpatialPass::runOnOperation() {
populateEmptyFunction(*entryFunc); populateEmptyFunction(*entryFunc);
dumpModule(moduleOp, "spatial0"); dumpModule(moduleOp, "spatial0");
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
signalPassFailure();
}
} }
std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); } std::unique_ptr<Pass> createONNXToSpatialPass() { return std::make_unique<ONNXToSpatialPass>(); }
@@ -0,0 +1,157 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
return;
for (Value result : op->getResults()) {
if (hasOnlySpatialMvmVmmWeightUses(result))
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
});
return;
}
});
}
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
<< kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
});
}
});
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError(
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
return success(!diagnostics.hasFailure());
}
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void) verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,20 +1,16 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGeneratedPrePatterns(patterns, ctx); }
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx);
populateElementwisePatterns(patterns, ctx); populateElementwisePatterns(patterns, ctx);
populateMatMulRewritePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx); populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx); populateConvPatterns(patterns, ctx);
populatePoolPatterns(patterns, ctx); populatePoolPatterns(patterns, ctx);
@@ -26,7 +22,13 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
populateGatherPatterns(patterns, ctx); populateGatherPatterns(patterns, ctx);
populateResizePatterns(patterns, ctx); populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx); populateReshapePatterns(patterns, ctx);
populateSlicePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx); populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateWeightPromotionPatterns(patterns, ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -1,38 +1,40 @@
#pragma once #pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir { namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSlicePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx); void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -0,0 +1,18 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
}
} // namespace onnx_mlir
File diff suppressed because it is too large Load Diff
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -83,7 +83,7 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
} }
auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues); auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), broadcastedAttr, resultType);
} }
static FailureOr<Value> static FailureOr<Value>
@@ -121,7 +121,7 @@ static FailureOr<Value> materializeReciprocalTensor(Value value,
} }
auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues); auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), reciprocalAttr, resultType);
} }
template <typename OnnxOp, typename SpatialOp> template <typename OnnxOp, typename SpatialOp>
@@ -189,6 +189,7 @@ struct DivToSpatialCompute : OpConversionPattern<ONNXDivOp> {
void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { void populateElementwisePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx); patterns.add<BinaryElementwiseToSpatialCompute<ONNXAddOp, spatial::SpatVAddOp>>(ctx);
patterns.add<BinaryElementwiseToSpatialCompute<ONNXSubOp, spatial::SpatVSubOp>>(ctx);
patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx); patterns.add<BinaryElementwiseToSpatialCompute<ONNXMulOp, spatial::SpatVMulOp>>(ctx);
patterns.add<DivToSpatialCompute>(ctx); patterns.add<DivToSpatialCompute>(ctx);
} }
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,13 +1,16 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include <algorithm> #include <algorithm>
#include <numeric>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -16,26 +19,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; axis++)
normalizedAxes.push_back(axis);
return normalizedAxes;
}
normalizedAxes.reserve(axesAttr.size());
for (Attribute attr : axesAttr) {
int64_t axis = cast<IntegerAttr>(attr).getInt();
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
}
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
return normalizedAxes;
}
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) { static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
SmallVector<bool> reducedAxes(rank, false); SmallVector<bool> reducedAxes(rank, false);
for (int64_t axis : axes) { for (int64_t axis : axes) {
@@ -50,6 +33,184 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType); return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
} }
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? 1 : dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
if (!isReduced)
shape.push_back(dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? dim : 1);
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
}
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
shape.front() = laneCount;
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
}
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> keptAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
if (!isReduced)
keptAxes.push_back(static_cast<int64_t>(axis));
return keptAxes;
}
static Value
computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternRewriter& rewriter, Location loc) {
if (dimSize == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr expr = d0;
if (stride != 1)
expr = expr.floorDiv(stride);
if (dimSize != 1)
expr = expr % dimSize;
return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
ArrayRef<bool> reducedAxes,
RankedTensorType batchType,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
auto sliceType = getReducedSliceType(inputType, reducedAxes);
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
int64_t laneCount = 1;
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
keptAxisStrides[index] = laneCount;
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
if (dimSize <= 0)
return failure();
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
return failure();
laneCount *= dimSize;
}
SmallVector<OpFoldResult> sliceOffsets;
SmallVector<OpFoldResult> sliceSizes;
SmallVector<OpFoldResult> insertOffsets;
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
sliceOffsets.reserve(inputType.getRank());
sliceSizes.reserve(inputType.getRank());
insertOffsets.reserve(inputType.getRank());
auto batchOp =
createSpatComputeBatch(rewriter,
loc,
TypeRange {batchType},
laneCount,
{},
ValueRange {input},
[&](detail::SpatComputeBatchBodyArgs args) {
size_t keptAxisIndex = 0;
sliceOffsets.clear();
sliceSizes.clear();
insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
Value axisIndex = computeLaneIndex(
args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex;
sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1));
}
insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
Value slice = tensor::ExtractSliceOp::create(
rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
});
if (failed(batchOp))
return failure();
return (*batchOp).getResult(0);
}
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
RankedTensorType keepdimsType,
RankedTensorType compactKeptType,
ArrayRef<bool> reducedAxes,
ConversionPatternRewriter& rewriter,
Location loc) {
auto batchType = cast<RankedTensorType>(batchValue.getType());
if (batchType == keepdimsType)
return batchValue;
SmallVector<ReassociationIndices> collapseToFlat {{}};
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
collapseToFlat.front().push_back(axis);
SmallVector<ReassociationIndices> expandFlatToCompact(1);
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
expandFlatToCompact.front().push_back(axis);
SmallVector<ReassociationIndices> expandCompactToKeepdims;
ReassociationIndices pendingLeadingReducedAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
if (expandCompactToKeepdims.empty())
pendingLeadingReducedAxes.push_back(axis);
else
expandCompactToKeepdims.back().push_back(axis);
continue;
}
expandCompactToKeepdims.emplace_back();
auto& group = expandCompactToKeepdims.back();
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
pendingLeadingReducedAxes.clear();
group.push_back(axis);
}
if (!pendingLeadingReducedAxes.empty())
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
auto reshapeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
auto flatType =
RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
Value compact = flat;
if (compactKeptType != flatType)
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
Value keepdims = compact;
if (keepdimsType != compactKeptType)
keepdims = tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
});
return reshapeCompute.getResult(0);
}
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) { static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
SmallVector<ReassociationIndices> reassociation; SmallVector<ReassociationIndices> reassociation;
ReassociationIndices currentGroup; ReassociationIndices currentGroup;
@@ -72,70 +233,14 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation; return reassociation;
} }
static Value
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
});
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes,
int64_t axis,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
if (axis == rank)
return createAverageCompute(input, leafType, rewriter, loc);
if (reducedAxes[axis])
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> reducedSlices;
reducedSlices.reserve(slices.size());
for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return concatValues(reducedSlices, axis, rewriter, loc);
}
static Value squeezeReducedAxes(Value keepdimsValue, static Value squeezeReducedAxes(Value keepdimsValue,
RankedTensorType resultType, RankedTensorType resultType,
ArrayRef<bool> reducedAxes, ArrayRef<bool> reducedAxes,
ConversionPatternRewriter& rewriter, ConversionPatternRewriter& rewriter,
Location loc) { Location loc) {
if (resultType.getRank() == 0) { SmallVector<ReassociationIndices> reassociation =
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(), resultType.getRank() == 0 ? SmallVector<ReassociationIndices> {} : buildCollapseReassociation(reducedAxes);
arith::ConstantIndexOp::create(rewriter, loc, 0)); if (isCompileTimeComputable(keepdimsValue))
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
}
auto reassociation = buildCollapseReassociation(reducedAxes);
if (isHostFoldableValue(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult(); return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute = auto squeezeCompute =
@@ -156,16 +261,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType()); auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
if (inputType.getRank() == 0) {
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
return success();
}
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank()); auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank()); if (failed(axes))
return failure();
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
if (reducedAxes.empty() && inputType.getRank() != 0) if (reducedAxes.empty() && inputType.getRank() != 0)
return failure(); return failure();
Location loc = reduceMeanOp.getLoc(); Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType()); RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
int64_t laneCount = 1;
for (int64_t dim : compactKeptType.getShape())
laneCount *= dim;
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
auto lanePackedKeepdims =
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
if (failed(lanePackedKeepdims))
return failure();
Value reducedKeepdims = Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc); buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) { if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims); rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
@@ -12,6 +12,7 @@
#include <optional> #include <optional>
#include <type_traits> #include <type_traits>
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp" #include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
@@ -23,43 +24,26 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
template <typename ArrayAttrT> static Value materializeTileTensor(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
return cast<IntegerAttr>(arrayAttr[index]).getInt();
}
template <typename ArrayAttrT>
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
}
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType()); auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType()); Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(tileType.getRank());
for (int64_t dimSize : tileType.getShape())
sizes.push_back(rewriter.getIndexAttr(dimSize));
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
} }
static Value static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) { createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
if (!useMinimumValue) if (!useMinimumValue)
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType)); return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType);
if (auto floatType = dyn_cast<FloatType>(elementType)) { if (auto floatType = dyn_cast<FloatType>(elementType)) {
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true); auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue)); return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType);
} }
if (auto integerType = dyn_cast<IntegerType>(elementType)) { if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth()); auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue)); return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType);
} }
llvm_unreachable("unsupported pool element type"); llvm_unreachable("unsupported pool element type");
@@ -166,7 +150,7 @@ static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewr
} }
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues); auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult(); return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType);
} }
template <typename PoolOp> template <typename PoolOp>
@@ -197,12 +181,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
const int64_t inputWidth = xType.getDimSize(3); const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2); const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3); const int64_t outputWidth = outType.getDimSize(3);
const int64_t kernelHeight = getI64(kernelAttr, 0); const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
const int64_t kernelWidth = getI64(kernelAttr, 1); const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1); const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1); const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1); const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1); const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
int64_t padTop = 0; int64_t padTop = 0;
int64_t padLeft = 0; int64_t padLeft = 0;
@@ -212,10 +196,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (auto padsAttr = poolOp.getPads()) { if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4) if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements."); return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
padTop = getI64(*padsAttr, 0); padTop = getI64Attr(*padsAttr, 0);
padLeft = getI64(*padsAttr, 1); padLeft = getI64Attr(*padsAttr, 1);
padBottom = getI64(*padsAttr, 2); padBottom = getI64Attr(*padsAttr, 2);
padRight = getI64(*padsAttr, 3); padRight = getI64Attr(*padsAttr, 3);
} }
else { else {
StringRef autoPad = poolOp.getAutoPad(); StringRef autoPad = poolOp.getAutoPad();
@@ -283,94 +267,111 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight); createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()); Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount); Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth); Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount);
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth); Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth);
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight); Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth);
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth); Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit}); auto outputLoop = buildNormalizedScfFor(
rewriter.setInsertionPointToStart(outputLoop.getBody()); rewriter,
loc,
c0,
cOutputPatchCount,
c1,
ValueRange {pooledOutputInit},
[&](OpBuilder&,
Location nestedLoc,
Value outputPatchIndex,
ValueRange iterArgs,
SmallVectorImpl<Value>& yielded) {
Value pooledOutputAcc = iterArgs.front();
Value batchIndex = arith::DivUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch);
Value batchPatchIndex =
arith::RemUIOp::create(rewriter, nestedLoc, outputPatchIndex, cOutputPixelsPerBatch);
Value outHeightIndex = arith::DivUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth);
Value outWidthIndex = arith::RemUIOp::create(rewriter, nestedLoc, batchPatchIndex, cOutputWidth);
Value windowBaseH = arith::MulIOp::create(rewriter, nestedLoc, outHeightIndex, cStrideHeight);
Value windowBaseW = arith::MulIOp::create(rewriter, nestedLoc, outWidthIndex, cStrideWidth);
Value outputPatchIndex = outputLoop.getInductionVar(); Value updatedOutput = pooledOutputAcc;
Value pooledOutputAcc = outputLoop.getRegionIterArgs().front(); for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) {
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize);
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType());
Value reducedWindow =
createPoolFillTensor(rewriter, nestedLoc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>);
Value batchIndex = arith::DivUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
Value batchPatchIndex = arith::RemUIOp::create(rewriter, loc, outputPatchIndex, cOutputPixelsPerBatch); Value paddedInH = windowBaseH;
Value outHeightIndex = arith::DivUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); if (kernelH * dilationHeight != 0) {
Value outWidthIndex = arith::RemUIOp::create(rewriter, loc, batchPatchIndex, cOutputWidth); Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
Value windowBaseH = arith::MulIOp::create(rewriter, loc, outHeightIndex, cStrideHeight); paddedInH = arith::AddIOp::create(rewriter, nestedLoc, paddedInH, kernelHOffset);
Value windowBaseW = arith::MulIOp::create(rewriter, loc, outWidthIndex, cStrideWidth); }
Value updatedOutput = pooledOutputAcc; for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
for (int64_t channelTile = 0; channelTile < channelTileCount; ++channelTile) { Value paddedInW = windowBaseW;
const int64_t tileChannels = std::min<int64_t>(xbarSize, channels - channelTile * xbarSize); if (kernelW * dilationWidth != 0) {
auto tileType = RankedTensorType::get({1, tileChannels, 1, 1}, outType.getElementType()); Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
Value reducedWindow = paddedInW = arith::AddIOp::create(rewriter, nestedLoc, paddedInW, kernelWOffset);
createPoolFillTensor(rewriter, loc, tileType, std::is_same_v<PoolOp, ONNXMaxPoolSingleOutOp>); }
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) { SmallVector<OpFoldResult> offsets = {
Value paddedInH = windowBaseH; batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW};
if (kernelH * dilationHeight != 0) { SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1),
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight); rewriter.getIndexAttr(tileChannels),
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset); rewriter.getIndexAttr(1),
} rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1),
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) { rewriter.getIndexAttr(1),
Value paddedInW = windowBaseW; rewriter.getIndexAttr(1),
if (kernelW * dilationWidth != 0) { rewriter.getIndexAttr(1)};
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth); Value windowValue =
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset); tensor::ExtractSliceOp::create(rewriter, nestedLoc, tileType, paddedInput, offsets, sizes, strides);
windowValue = materializeTileTensor(rewriter, nestedLoc, windowValue);
reducedWindow = ReduceOp::create(rewriter, nestedLoc, tileType, reducedWindow, windowValue);
}
} }
SmallVector<OpFoldResult> offsets = { if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), paddedInH, paddedInW}; SmallVector<OpFoldResult> scaleOffsets = {rewriter.getIndexAttr(0),
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(channelTile * xbarSize),
rewriter.getIndexAttr(tileChannels), outHeightIndex,
rewriter.getIndexAttr(1), outWidthIndex};
rewriter.getIndexAttr(1)}; SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
SmallVector<OpFoldResult> strides = { rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, nestedLoc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeTileTensor(rewriter, nestedLoc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, nestedLoc, tileType, reducedWindow, scaleSlice);
}
SmallVector<OpFoldResult> outputOffsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)}; rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value windowValue = updatedOutput = tensor::InsertSliceOp::create(
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides); rewriter, nestedLoc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
} }
} yielded.push_back(updatedOutput);
return success();
});
if (failed(outputLoop))
return failure();
if constexpr (std::is_same_v<PoolOp, ONNXAveragePoolOp>) { spatial::SpatYieldOp::create(rewriter, loc, outputLoop->results.front());
SmallVector<OpFoldResult> scaleOffsets = {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> scaleSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> scaleStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
}
SmallVector<OpFoldResult> outputOffsets = {
batchIndex, rewriter.getIndexAttr(channelTile * xbarSize), outHeightIndex, outWidthIndex};
SmallVector<OpFoldResult> outputSizes = {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(tileChannels),
rewriter.getIndexAttr(1),
rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> outputStrides = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
updatedOutput = tensor::InsertSliceOp::create(
rewriter, loc, reducedWindow, updatedOutput, outputOffsets, outputSizes, outputStrides);
}
scf::YieldOp::create(rewriter, loc, updatedOutput);
rewriter.setInsertionPointAfter(outputLoop);
spatial::SpatYieldOp::create(rewriter, loc, outputLoop.getResult(0));
return success(); return success();
}); });
if (failed(computeOp)) if (failed(computeOp))
@@ -1,9 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -12,61 +14,94 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; } static Value buildLoopSoftmaxSlice(Value input,
Value accumulator,
RankedTensorType inputType,
ArrayRef<Value> outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = inputType.getRank();
SmallVector<int64_t> sliceShape(static_cast<size_t>(rank - 1), 1);
sliceShape.push_back(inputType.getDimSize(rank - 1));
auto sliceType = RankedTensorType::get(sliceShape, inputType.getElementType(), inputType.getEncoding());
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) { SmallVector<OpFoldResult> offsets;
SmallVector<int64_t> permutedShape; SmallVector<OpFoldResult> sizes;
permutedShape.reserve(permutation.size()); SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
for (int64_t axis : permutation) offsets.reserve(rank);
permutedShape.push_back(shape[axis]); sizes.reserve(rank);
return permutedShape;
for (Value outerIndex : outerIndices) {
offsets.push_back(outerIndex);
sizes.push_back(rewriter.getIndexAttr(1));
}
offsets.push_back(rewriter.getIndexAttr(0));
sizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(rank - 1)));
Value inputSlice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
Value softmaxSlice = spatial::SpatSoftmaxOp::create(rewriter, loc, sliceType, inputSlice).getResult();
return tensor::InsertSliceOp::create(rewriter, loc, softmaxSlice, accumulator, offsets, sizes, strides);
} }
static Value createSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) { static FailureOr<Value> buildLoopSoftmaxNest(Value input,
Value accumulator,
RankedTensorType inputType,
int64_t axis,
SmallVectorImpl<Value>& outerIndices,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == inputType.getRank() - 1)
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
auto loop = buildNormalizedScfFor(
rewriter,
loc,
c0,
cUpper,
c1,
ValueRange {accumulator},
[&](OpBuilder& builder, Location nestedLoc, Value loopIndex, ValueRange iterArgs, SmallVectorImpl<Value>& yielded) {
outerIndices.push_back(loopIndex);
auto updatedAccumulator =
buildLoopSoftmaxNest(input, iterArgs.front(), inputType, axis + 1, outerIndices, rewriter, nestedLoc);
outerIndices.pop_back();
if (failed(updatedAccumulator))
return failure();
yielded.push_back(*updatedAccumulator);
return success();
});
if (failed(loop))
return failure();
return loop->results.front();
}
static FailureOr<Value> createLoopSoftmaxCompute(Value input, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); auto inputType = cast<RankedTensorType>(input.getType());
constexpr size_t numInputs = 1; constexpr size_t numInputs = 1;
auto computeOp = auto computeOp = createSpatCompute<numInputs>(
createSpatCompute<numInputs>(rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) { rewriter, loc, TypeRange {inputType}, {}, ValueRange {input}, [&](Value x) -> LogicalResult {
auto softmaxOp = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x); if (inputType.getRank() == 1) {
spatial::SpatYieldOp::create(rewriter, loc, softmaxOp.getResult()); Value softmax = spatial::SpatSoftmaxOp::create(rewriter, loc, inputType, x).getResult();
spatial::SpatYieldOp::create(rewriter, loc, softmax);
return success();
}
Value outputInit = tensor::EmptyOp::create(rewriter, loc, inputType.getShape(), inputType.getElementType());
SmallVector<Value> outerIndices;
auto result = buildLoopSoftmaxNest(x, outputInit, inputType, /*axis=*/0, outerIndices, rewriter, loc);
if (failed(result))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *result);
return success();
}); });
return computeOp.getResult(0); if (failed(computeOp))
} return failure();
return computeOp->getResult(0);
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value
buildSoftmax(Value input, int64_t softmaxAxis, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (axis == inputType.getRank())
return createSoftmaxCompute(input, rewriter, loc);
if (axis == softmaxAxis)
return buildSoftmax(input, softmaxAxis, axis + 1, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> rebuiltSlices;
rebuiltSlices.reserve(slices.size());
for (Value slice : slices)
rebuiltSlices.push_back(buildSoftmax(slice, softmaxAxis, axis + 1, rewriter, loc));
return concatValues(rebuiltSlices, axis, rewriter, loc);
} }
struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> { struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
@@ -79,45 +114,40 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
if (!inputType || !inputType.hasStaticShape()) if (!inputType || !inputType.hasStaticShape())
return failure(); return failure();
int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank()); auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
if (axis < 0 || axis >= inputType.getRank()) if (failed(axis))
return failure(); return failure();
Value input = adaptor.getInput(); Value input = adaptor.getInput();
Value result; Value result;
if (axis == inputType.getRank() - 1) { if (*axis == inputType.getRank() - 1) {
result = buildSoftmax(input, axis, /*axis=*/0, rewriter, softmaxOp.getLoc()); auto computed = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
if (failed(computed))
return failure();
result = *computed;
} }
else { else {
SmallVector<int64_t> permutation; SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank()); permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
if (dim != axis) if (dim != *axis)
permutation.push_back(dim); permutation.push_back(dim);
permutation.push_back(axis); permutation.push_back(*axis);
SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
SmallVector<int64_t> inversePermutation(inputType.getRank());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
auto transposedType = RankedTensorType::get( auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding()); permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
auto preTransposeCompute = Value transposedInput =
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) { ONNXTransposeOp::create(
Value transposed = ONNXTransposeOp::create( rewriter, softmaxOp.getLoc(), transposedType, input, rewriter.getI64ArrayAttr(permutation))
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation)); .getResult();
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed); auto transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
}); if (failed(transposedResult))
Value transposedInput = preTransposeCompute.getResult(0); return failure();
Value transposedResult = buildSoftmax( result =
transposedInput, /*softmaxAxis=*/inputType.getRank() - 1, /*axis=*/0, rewriter, softmaxOp.getLoc()); ONNXTransposeOp::create(
auto postTransposeCompute = rewriter, softmaxOp.getLoc(), inputType, *transposedResult, rewriter.getI64ArrayAttr(inversePermutation))
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) { .getResult();
Value transposed = ONNXTransposeOp::create(
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
});
result = postTransposeCompute.getResult(0);
} }
rewriter.replaceOp(softmaxOp, result); rewriter.replaceOp(softmaxOp, result);
@@ -0,0 +1,288 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
return arg && canPromoteInputBlockArgument(*arg);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
struct PromotedOperands {
SmallVector<bool> promoteInput;
SmallVector<Value> newWeights;
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
};
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
return true;
}
return false;
}
template <typename ComputeOpTy>
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
PromotedOperands promoted;
promoted.promoteInput.assign(compute.getInputs().size(), false);
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
promoted.newInputs.reserve(compute.getInputs().size());
promoted.newInputTypes.reserve(compute.getInputs().size());
promoted.newInputLocs.reserve(compute.getInputs().size());
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
goto keep_input;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
goto keep_input;
promoted.promoteInput[inputIdx] = true;
promoted.newWeights.push_back(input);
needsRewrite = true;
continue;
keep_input:
promoted.newInputs.push_back(input);
promoted.newInputTypes.push_back(input.getType());
promoted.newInputLocs.push_back(input.getLoc());
}
if (!needsRewrite)
return failure();
return promoted;
}
template <typename ComputeOpTy>
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
const PromotedOperands& promoted,
IRRewriter& bodyRewriter,
IRMapping& mapper,
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
PatternRewriter& rewriter) {
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
if (!promoted.promoteInput[oldInputIdx]) {
auto newInputArg = getNewInputArg(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
return success();
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto newCompute = spatial::SpatCompute::create(
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
if (failed(mapPromotedInputArguments(
compute,
*promoted,
bodyRewriter,
mapper,
[&](size_t index) { return newCompute.getInputArgument(index); },
rewriter)))
return failure();
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto laneCountAttr = pim::getCheckedI32Attr(
rewriter, compute, static_cast<uint64_t>(compute.getLaneCount()), "promoted compute_batch lane count");
if (failed(laneCountAttr))
return failure();
auto newCompute = spatial::SpatComputeBatch::create(
rewriter, compute.getLoc(), compute.getResultTypes(), *laneCountAttr, promoted->newWeights, promoted->newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
+ compute.getNumResults());
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(outputArg->getLoc());
}
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto newLaneArg = newCompute.getLaneArgument();
if (!newLaneArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
mapper.map(*laneArg, *newLaneArg);
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
if (failed(mapPromotedInputArguments(
compute,
*promoted,
bodyRewriter,
mapper,
[&](size_t index) { return newCompute.getInputArgument(index); },
rewriter)))
return failure();
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg,
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
}
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
@@ -1,6 +1,5 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
using namespace mlir; using namespace mlir;
@@ -12,14 +11,12 @@ namespace {
} // namespace } // namespace
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) { void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx); patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx); patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx); patterns.add<convAddToConvWithBiasRight>(ctx);
patterns.add<matMulAddToGemm>(ctx); patterns.add<matMulAddToGemm>(ctx);
patterns.add<matMulToGemm>(ctx);
patterns.add<removeFlattenSameShape>(ctx); patterns.add<removeFlattenSameShape>(ctx);
populateMatMulRewritePatterns(patterns, ctx);
} }
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs(); auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis(); int64_t axis = adaptor.getAxis();
if (llvm::all_of(inputs, isHostFoldableValue)) { if (llvm::all_of(inputs, isCompileTimeComputable)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs)); rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success(); return success();
} }
@@ -6,7 +6,7 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,24 +15,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static Value
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(1);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
}
static Value concatGatherSlices(Value data, static Value concatGatherSlices(Value data,
int64_t axis, int64_t axis,
ArrayRef<int64_t> indices, ArrayRef<int64_t> indices,
@@ -45,7 +27,7 @@ static Value concatGatherSlices(Value data,
int64_t normalizedIndex = normalizeIndex(index, axisDim); int64_t normalizedIndex = normalizeIndex(index, axisDim);
if (normalizedIndex < 0 || normalizedIndex >= axisDim) if (normalizedIndex < 0 || normalizedIndex >= axisDim)
return {}; return {};
slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc)); slices.push_back(extractAxisSlice(rewriter, loc, data, axis, normalizedIndex, /*size=*/1));
} }
if (slices.empty()) if (slices.empty())
return {}; return {};
@@ -96,11 +78,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure(); return failure();
int64_t rank = dataType.getRank(); int64_t rank = dataType.getRank();
int64_t axis = normalizeAxis(gatherOp.getAxis(), rank); auto axis = normalizeAxisChecked(gatherOp.getAxis(), rank);
if (axis < 0 || axis >= rank) if (failed(axis))
return failure(); return failure();
int64_t axisDim = dataType.getShape()[axis]; int64_t axisDim = dataType.getShape()[*axis];
if (axisDim <= 0) if (axisDim <= 0)
return failure(); return failure();
@@ -116,7 +98,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
[&](Value data) -> LogicalResult { [&](Value data) -> LogicalResult {
Value result; Value result;
if (indicesType.getRank() == 1) { if (indicesType.getRank() == 1) {
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc); result = concatGatherSlices(data, *axis, flatIndices, axisDim, rewriter, loc);
} }
else if (indicesType.getRank() == 2) { else if (indicesType.getRank() == 2) {
int64_t rowCount = indicesType.getShape()[0]; int64_t rowCount = indicesType.getShape()[0];
@@ -125,12 +107,13 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
rows.reserve(rowCount); rows.reserve(rowCount);
for (int64_t row = 0; row < rowCount; ++row) { for (int64_t row = 0; row < rowCount; ++row) {
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth); ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc); Value gatheredRow =
concatGatherSlices(data, *axis, rowIndices, axisDim, rewriter, loc);
if (!gatheredRow) if (!gatheredRow)
return failure(); return failure();
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc)); rows.push_back(addLeadingGatherDim(gatheredRow, *axis, rewriter, loc));
} }
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows); result = createSpatConcat(rewriter, loc, /*axis=*/*axis, rows);
} }
else { else {
return failure(); return failure();
@@ -4,8 +4,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -14,10 +14,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape, static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) { SmallVector<ReassociationIndices>& reassociation) {
@@ -80,6 +76,22 @@ static bool inferExpandReassociation(ArrayRef<int64_t> sourceShape,
return sourceIdx == sourceShape.size() && resultIdx == resultShape.size(); return sourceIdx == sourceShape.size() && resultIdx == resultShape.size();
} }
static SmallVector<ReassociationIndices> getCollapseTo1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
static SmallVector<ReassociationIndices> getExpandFrom1DReassociation(size_t rank) {
SmallVector<ReassociationIndices> reassociation(1);
reassociation.front().reserve(rank);
for (size_t dim = 0; dim < rank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
return reassociation;
}
struct Reshape : OpConversionPattern<ONNXReshapeOp> { struct Reshape : OpConversionPattern<ONNXReshapeOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -90,7 +102,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType()); auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape()) if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape())) if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType))
return failure(); return failure();
if (sourceType == resultType) { if (sourceType == resultType) {
@@ -99,17 +111,9 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
} }
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult { auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
if (isHostFoldableValue(adaptor.getData())) { Value reshaped =
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData())); materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
return success(); rewriter.replaceOp(reshapeOp, reshaped);
}
auto computeOp = createSpatCompute<1>(
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
Value reshaped = buildReshape(data);
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
});
rewriter.replaceOp(reshapeOp, computeOp.getResults());
return success(); return success();
}; };
@@ -126,6 +130,23 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation); return tensor::ExpandShapeOp::create(rewriter, reshapeOp.getLoc(), resultType, data, reassociation);
}); });
if (sourceType.getNumElements() != resultType.getNumElements())
return failure();
return replaceWithReshape([&](Value data) -> Value {
Value reshaped = data;
if (sourceType.getRank() != 1) {
auto flatType = RankedTensorType::get({sourceType.getNumElements()}, sourceType.getElementType());
reshaped = tensor::CollapseShapeOp::create(
rewriter, reshapeOp.getLoc(), flatType, reshaped, getCollapseTo1DReassociation(sourceType.getRank()));
}
if (resultType.getRank() == 1)
return reshaped;
return tensor::ExpandShapeOp::create(
rewriter, reshapeOp.getLoc(), resultType, reshaped, getExpandFrom1DReassociation(resultType.getRank()))
.getResult();
});
return failure(); return failure();
} }
}; };
@@ -1,12 +1,13 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include <algorithm> #include "src/Accelerators/PIM/Common/IR/LoopUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,42 +16,127 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static Value static Value buildNearestAsymmetricIndex(
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) { Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType()); Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0)); Value cInputDim = getOrCreateIndexConstant(rewriter, anchorOp, inputDim);
SmallVector<OpFoldResult> sizes; Value cOutputDim = getOrCreateIndexConstant(rewriter, anchorOp, outputDim);
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1)); Value cInputDimLast = getOrCreateIndexConstant(rewriter, anchorOp, inputDim - 1);
sizes.reserve(inputType.getRank()); Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
for (int64_t dim : inputType.getShape()) Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
sizes.push_back(rewriter.getIndexAttr(dim)); return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(1);
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
} }
static int64_t nearestAsymmetricIndex(int64_t outputIndex, int64_t inputDim, int64_t outputDim) { static FailureOr<Value> buildNearestResizeLoop(Value input,
return std::min<int64_t>((outputIndex * inputDim) / outputDim, inputDim - 1); RankedTensorType inputType,
} RankedTensorType resultType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto elemType = resultType.getElementType();
SmallVector<int64_t> unitShape(resultType.getRank(), 1);
auto unitTensorType = RankedTensorType::get(unitShape, elemType);
static Value buildNearestResize(Value input, SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
ArrayRef<int64_t> inputShape, SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
ArrayRef<int64_t> outputShape,
int64_t axis,
ConversionPatternRewriter& rewriter,
Location loc) {
if (axis == static_cast<int64_t>(outputShape.size()))
return input;
SmallVector<Value> slices; Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
slices.reserve(outputShape[axis]); Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
for (int64_t outputIndex = 0; outputIndex < outputShape[axis]; ++outputIndex) { Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
int64_t inputIndex = nearestAsymmetricIndex(outputIndex, inputShape[axis], outputShape[axis]); Value cOutputN = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(0));
Value slice = extractSliceAt(input, axis, inputIndex, rewriter, loc); Value cOutputC = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(1));
slices.push_back(buildNearestResize(slice, inputShape, outputShape, axis + 1, rewriter, loc)); Value cOutputH = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(2));
} Value cOutputW = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(3));
return createSpatConcat(rewriter, loc, axis, slices); Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
auto batchLoop = buildNormalizedScfFor(
rewriter,
loc,
c0,
cOutputN,
c1,
ValueRange {outputInit},
[&](OpBuilder&, Location nestedLoc, Value outputN, ValueRange batchIterArgs, SmallVectorImpl<Value>& batchYielded) {
Value outputBatchAcc = batchIterArgs.front();
Value inputN =
buildNearestAsymmetricIndex(outputN, inputType.getDimSize(0), resultType.getDimSize(0), rewriter, nestedLoc);
auto channelLoop = buildNormalizedScfFor(
rewriter,
nestedLoc,
c0,
cOutputC,
c1,
ValueRange {outputBatchAcc},
[&](OpBuilder&,
Location channelLoc,
Value outputC,
ValueRange channelIterArgs,
SmallVectorImpl<Value>& channelYielded) {
Value outputChannelAcc = channelIterArgs.front();
Value inputC = buildNearestAsymmetricIndex(
outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, channelLoc);
auto heightLoop = buildNormalizedScfFor(
rewriter,
channelLoc,
c0,
cOutputH,
c1,
ValueRange {outputChannelAcc},
[&](OpBuilder&,
Location heightLoc,
Value outputH,
ValueRange heightIterArgs,
SmallVectorImpl<Value>& heightYielded) {
Value outputHeightAcc = heightIterArgs.front();
Value inputH = buildNearestAsymmetricIndex(
outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, heightLoc);
auto widthLoop = buildNormalizedScfFor(
rewriter,
heightLoc,
c0,
cOutputW,
c1,
ValueRange {outputHeightAcc},
[&](OpBuilder&,
Location widthLoc,
Value outputW,
ValueRange widthIterArgs,
SmallVectorImpl<Value>& widthYielded) {
Value outputWidthAcc = widthIterArgs.front();
Value inputW = buildNearestAsymmetricIndex(
outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, widthLoc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice = tensor::ExtractSliceOp::create(
rewriter, widthLoc, unitTensorType, input, inputOffsets, unitSizes, unitStrides);
SmallVector<OpFoldResult> outputOffsets = {outputN, outputC, outputH, outputW};
Value updatedOutput = tensor::InsertSliceOp::create(
rewriter, widthLoc, inputSlice, outputWidthAcc, outputOffsets, unitSizes, unitStrides);
widthYielded.push_back(updatedOutput);
return success();
});
if (failed(widthLoop))
return failure();
heightYielded.push_back(widthLoop->results.front());
return success();
});
if (failed(heightLoop))
return failure();
channelYielded.push_back(heightLoop->results.front());
return success();
});
if (failed(channelLoop))
return failure();
batchYielded.push_back(channelLoop->results.front());
return success();
});
if (failed(batchLoop))
return failure();
return batchLoop->results.front();
} }
struct Resize : OpConversionPattern<ONNXResizeOp> { struct Resize : OpConversionPattern<ONNXResizeOp> {
@@ -62,23 +148,30 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType()); auto inputType = dyn_cast<RankedTensorType>(adaptor.getX().getType());
auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType()); auto resultType = dyn_cast<RankedTensorType>(resizeOp.getY().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape()) if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires static ranked tensor types.");
if (inputType.getRank() != 4 || resultType.getRank() != 4)
return rewriter.notifyMatchFailure(resizeOp, "resize lowering currently supports only rank-4 NCHW tensors.");
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric" if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor") || resizeOp.getNearestMode() != "floor")
return failure(); return rewriter.notifyMatchFailure(resizeOp,
"resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; }) if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; })) || llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
return failure(); return rewriter.notifyMatchFailure(resizeOp, "resize lowering requires positive static dimensions.");
auto computeOp = auto computeOp = createSpatCompute<1>(
createSpatCompute<1>(rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) { rewriter, resizeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getX(), [&](Value x) -> LogicalResult {
Value result = auto result = buildNearestResizeLoop(x, inputType, resultType, rewriter, resizeOp.getLoc());
buildNearestResize(x, inputType.getShape(), resultType.getShape(), /*axis=*/0, rewriter, resizeOp.getLoc()); if (failed(result))
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), result); return failure();
spatial::SpatYieldOp::create(rewriter, resizeOp.getLoc(), *result);
return success();
}); });
rewriter.replaceOp(resizeOp, computeOp.getResults()); if (failed(computeOp))
return failure();
rewriter.replaceOp(resizeOp, computeOp->getResults());
return success(); return success();
} }
}; };
@@ -0,0 +1,189 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <optional>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static FailureOr<SmallVector<int64_t>> getConstantIntValues(Value value) {
auto denseAttr = dyn_cast_or_null<DenseIntElementsAttr>(getHostConstDenseElementsAttr(value));
if (!denseAttr)
return failure();
return SmallVector<int64_t>(denseAttr.getValues<int64_t>().begin(), denseAttr.getValues<int64_t>().end());
}
static bool isNoneValueLike(Value value) { return isa_and_nonnull<ONNXNoneOp>(value.getDefiningOp()); }
static FailureOr<Value> buildSlice(Value data,
RankedTensorType dataType,
RankedTensorType resultType,
ArrayRef<int64_t> starts,
ArrayRef<int64_t> ends,
std::optional<ArrayRef<int64_t>> axes,
std::optional<ArrayRef<int64_t>> steps,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = dataType.getRank();
if (!dataType.hasStaticShape() || !resultType.hasStaticShape() || resultType.getRank() != rank)
return failure();
if (starts.size() != ends.size())
return failure();
if (axes && axes->size() != starts.size())
return failure();
if (steps && steps->size() != starts.size())
return failure();
SmallVector<int64_t> normalizedAxes;
if (axes) {
SmallVector<bool> seenAxes(rank, false);
normalizedAxes.reserve(axes->size());
for (int64_t axis : *axes) {
auto normalizedAxis = normalizeAxisChecked(axis, rank);
if (failed(normalizedAxis))
return failure();
if (seenAxes[*normalizedAxis])
return failure();
seenAxes[*normalizedAxis] = true;
normalizedAxes.push_back(*normalizedAxis);
}
}
else {
if (starts.size() > static_cast<size_t>(rank))
return failure();
normalizedAxes.reserve(starts.size());
for (size_t i = 0; i < starts.size(); ++i)
normalizedAxes.push_back(static_cast<int64_t>(i));
}
SmallVector<int64_t> normalizedSteps;
if (steps)
normalizedSteps.assign(steps->begin(), steps->end());
else
normalizedSteps.assign(starts.size(), 1);
SmallVector<int64_t> computedShape(dataType.getShape().begin(), dataType.getShape().end());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, rank);
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, dataType.getShape());
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
for (auto [sliceIndex, axis] : llvm::enumerate(normalizedAxes)) {
int64_t step = normalizedSteps[sliceIndex];
if (step <= 0)
return failure();
int64_t dimSize = dataType.getShape()[axis];
int64_t start = starts[sliceIndex];
int64_t end = ends[sliceIndex];
start = normalizeIndex(start, dimSize);
end = normalizeIndex(end, dimSize);
start = std::clamp(start, int64_t {0}, dimSize);
end = std::clamp(end, int64_t {0}, dimSize);
int64_t extent = std::max(end - start, int64_t {0});
int64_t size = (extent + step - 1) / step;
offsets[axis] = rewriter.getIndexAttr(start);
sizes[axis] = rewriter.getIndexAttr(size);
strides[axis] = rewriter.getIndexAttr(step);
computedShape[axis] = size;
}
if (llvm::ArrayRef(computedShape) != resultType.getShape())
return failure();
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, data, offsets, sizes, strides).getResult();
}
struct Slice final : OpConversionPattern<ONNXSliceOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXSliceOp sliceOp,
ONNXSliceOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto dataType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(sliceOp.getResult().getType());
if (!dataType || !resultType || !dataType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
auto starts = getConstantIntValues(adaptor.getStarts());
auto ends = getConstantIntValues(adaptor.getEnds());
if (failed(starts))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant starts");
if (failed(ends))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant ends");
std::optional<SmallVector<int64_t>> axes;
if (!isNoneValueLike(adaptor.getAxes())) {
auto parsedAxes = getConstantIntValues(adaptor.getAxes());
if (failed(parsedAxes))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant axes when present");
axes = std::move(*parsedAxes);
}
std::optional<SmallVector<int64_t>> steps;
if (!isNoneValueLike(adaptor.getSteps())) {
auto parsedSteps = getConstantIntValues(adaptor.getSteps());
if (failed(parsedSteps))
return rewriter.notifyMatchFailure(sliceOp, "requires compile-time constant steps when present");
steps = std::move(*parsedSteps);
if (llvm::any_of(*steps, [](int64_t step) { return step <= 0; }))
return rewriter.notifyMatchFailure(sliceOp, "supports only positive constant steps");
}
ArrayRef<int64_t> startsRef = *starts;
ArrayRef<int64_t> endsRef = *ends;
std::optional<ArrayRef<int64_t>> axesRef = axes ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*axes))
: std::nullopt;
std::optional<ArrayRef<int64_t>> stepsRef = steps ? std::optional<ArrayRef<int64_t>>(ArrayRef<int64_t>(*steps))
: std::nullopt;
Location loc = sliceOp.getLoc();
auto tryBuildSlice = [&](Value data) {
return buildSlice(data, dataType, resultType, startsRef, endsRef, axesRef, stepsRef, rewriter, loc);
};
if (isCompileTimeComputable(adaptor.getData())) {
auto sliced = tryBuildSlice(adaptor.getData());
if (failed(sliced))
return rewriter.notifyMatchFailure(sliceOp, "failed to normalize static slice parameters");
rewriter.replaceOp(sliceOp, *sliced);
return success();
}
auto computeOp =
createSpatCompute<1>(rewriter, loc, TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
auto sliced = tryBuildSlice(data);
if (failed(sliced))
return failure();
spatial::SpatYieldOp::create(rewriter, loc, *sliced);
return success();
});
if (failed(computeOp))
return rewriter.notifyMatchFailure(sliceOp, "failed to build runtime tensor.extract_slice lowering");
rewriter.replaceOp(sliceOp, computeOp->getResults());
return success();
}
};
} // namespace
void populateSlicePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { patterns.add<Slice>(ctx); }
} // namespace onnx_mlir
@@ -2,8 +2,8 @@
#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp" #include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp" #include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -12,25 +12,6 @@ using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static Value extractSliceAt(
Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
SmallVector<int64_t> resultShape(inputType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
}
struct Split : OpConversionPattern<ONNXSplitOp> { struct Split : OpConversionPattern<ONNXSplitOp> {
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
@@ -41,8 +22,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
return failure(); return failure();
int64_t rank = inputType.getRank(); int64_t rank = inputType.getRank();
int64_t axis = normalizeAxis(splitOp.getAxis(), rank); auto axis = normalizeAxisChecked(splitOp.getAxis(), rank);
if (axis < 0 || axis >= rank) if (failed(axis))
return failure(); return failure();
SmallVector<Value> outputs; SmallVector<Value> outputs;
@@ -58,12 +39,12 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
if (!resultType || !resultType.hasStaticShape()) if (!resultType || !resultType.hasStaticShape())
return failure(); return failure();
resultTypes.push_back(resultType); resultTypes.push_back(resultType);
sliceSizes.push_back(resultType.getShape()[axis]); sliceSizes.push_back(resultType.getShape()[*axis]);
} }
if (isHostFoldableValue(adaptor.getInput())) { if (isCompileTimeComputable(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) { for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc())); outputs.push_back(extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
offset += sliceSize; offset += sliceSize;
} }
rewriter.replaceOp(splitOp, outputs); rewriter.replaceOp(splitOp, outputs);
@@ -76,7 +57,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
runtimeOutputs.reserve(resultTypes.size()); runtimeOutputs.reserve(resultTypes.size());
int64_t runtimeOffset = 0; int64_t runtimeOffset = 0;
for (int64_t sliceSize : sliceSizes) { for (int64_t sliceSize : sliceSizes) {
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc())); runtimeOutputs.push_back(
extractAxisSlice(rewriter, splitOp.getLoc(), input, *axis, runtimeOffset, sliceSize));
runtimeOffset += sliceSize; runtimeOffset += sliceSize;
} }
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs); spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
@@ -0,0 +1,135 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isInsideSpatialComputeRegion(Operation* op) {
return op->getParentOfType<spatial::SpatCompute>() || op->getParentOfType<spatial::SpatComputeBatch>();
}
static Value createTransposeInit(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
if (!ShapedType::isDynamic(resultDim)) {
sizes.push_back(rewriter.getIndexAttr(resultDim));
continue;
}
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
}
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
}
static FailureOr<Value> materializeTransposedConstant(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
auto denseAttr = getHostConstDenseElementsAttr(input);
if (!denseAttr)
return failure();
auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!inputType || !inputType.hasStaticShape() || !resultType.hasStaticShape()
|| inputType.getRank() != resultType.getRank()
|| static_cast<int64_t>(permutation.size()) != inputType.getRank()) {
return failure();
}
if (denseAttr.isSplat())
return getOrCreateConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>()),
resultType);
SmallVector<Attribute> inputValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> resultValues(inputValues.size());
SmallVector<int64_t> inputStrides = computeRowMajorStrides(inputType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<int64_t> inputIndices(inputType.getRank(), 0);
for (auto [linearIndex, value] : llvm::enumerate(inputValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
inputIndices[dim] = inputStrides.empty() ? 0 : remaining / inputStrides[dim];
remaining = inputStrides.empty() ? 0 : remaining % inputStrides[dim];
}
int64_t resultLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim)
resultLinearIndex += inputIndices[permutation[dim]] * resultStrides[dim];
resultValues[resultLinearIndex] = value;
}
return getOrCreateConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
DenseElementsAttr::get(resultType, resultValues),
resultType);
}
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
ONNXTransposeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
if (!inputType || !resultType)
return failure();
auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
if (failed(permutation))
return failure();
if (isCompileTimeComputable(adaptor.getData())) {
auto constantTranspose =
materializeTransposedConstant(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
if (succeeded(constantTranspose)) {
rewriter.replaceOp(transposeOp, *constantTranspose);
return success();
}
}
auto buildTranspose = [&](Value input) -> Value {
Value init = createTransposeInit(input, resultType, *permutation, rewriter, transposeOp.getLoc());
return linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), input, init, *permutation).getResult()[0];
};
if (isInsideSpatialComputeRegion(transposeOp.getOperation())) {
rewriter.replaceOp(transposeOp, buildTranspose(adaptor.getData()));
return success();
}
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {adaptor.getData()}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), buildTranspose(input));
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
return success();
}
};
} // namespace
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<TransposeToLinalgTranspose>(ctx);
}
} // namespace onnx_mlir
@@ -1,265 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
} // namespace onnx_mlir
@@ -1,14 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
namespace onnx_mlir {
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -202,6 +202,7 @@ void SpatialToGraphvizPass::runOnOperation() {
auto entryFunc = getPimEntryFunc(module); auto entryFunc = getPimEntryFunc(module);
if (failed(entryFunc)) { if (failed(entryFunc)) {
module.emitError("failed to locate the PIM entry function for Spatial graph visualization");
signalPassFailure(); signalPassFailure();
return; return;
} }
@@ -1,11 +1,16 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h" #include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp" #include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp" #include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -15,85 +20,131 @@ using namespace onnx_mlir::pim;
namespace onnx_mlir { namespace onnx_mlir {
namespace { namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); } static bool isUsedOnlyAsExplicitHostOperand(Value value) {
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
return isExplicitDevToHostTargetOperand(use.getOwner(), use.getOperandNumber());
});
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) { static FailureOr<SmallVector<int32_t>> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp,
size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName)) if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end()); return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
SmallVector<int32_t> coreIds; SmallVector<int32_t> coreIds;
coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount())); coreIds.reserve(static_cast<size_t>(computeBatchOp.getLaneCount()));
for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) for (uint32_t lane = 0; lane < computeBatchOp.getLaneCount(); ++lane) {
coreIds.push_back(static_cast<int32_t>(fallbackCoreId++)); auto checkedCoreId =
pim::checkedI32(static_cast<uint64_t>(fallbackCoreId), computeBatchOp, "fallback spatial compute_batch core id");
if (failed(checkedCoreId))
return failure();
coreIds.push_back(*checkedCoreId);
++fallbackCoreId;
}
return coreIds; return coreIds;
} }
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp, static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
IRMapping& mapper, if (!result.hasOneUse())
IRRewriter& rewriter) { return failure();
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
pim::PimSendTensorBatchOp::create(rewriter, auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
sendTensorBatchOp.getLoc(), if (!returnOp)
mapper.lookup(sendTensorBatchOp.getInput()), return failure();
rewriter.getDenseI32ArrayAttr(targetCoreIds)); return result.getUses().begin()->getOperandNumber();
} }
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp, static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
IRMapping& mapper, if (scale == 1)
IRRewriter& rewriter) { return base;
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType()); auto scaleValue = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scale);
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType); return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
Value received = pim::PimReceiveTensorBatchOp::create(rewriter, }
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(), static Value createHostTargetOffset(IRRewriter& rewriter,
outputBuffer, tensor::ParallelInsertSliceOp insertSlice,
rewriter.getDenseI32ArrayAttr(sourceCoreIds)) ShapedType destinationType,
.getOutput(); IRMapping& mapper) {
mapper.map(receiveTensorBatchOp.getOutput(), received); int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
}
else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset =
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
}
if (!totalOffset)
totalOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
return totalOffset;
} }
} // namespace } // namespace
LogicalResult LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) { IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
Location loc = computeBatchOp.getLoc(); Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front(); Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator()); auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0) auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield"); if (computeBatchOp.getNumResults() == 0) {
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
}
else if (!inParallelOp) {
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId); auto coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
if (failed(coreIds))
return failure();
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end()); SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs; SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty()) if (!computeBatchOp.getInputs().empty())
batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end()); batchInputs.append(computeBatchOp.getInputs().begin(), computeBatchOp.getInputs().end());
rewriter.setInsertionPointAfter(computeBatchOp); rewriter.setInsertionPointAfter(computeBatchOp);
auto coreBatchOp = pim::PimCoreBatchOp::create(rewriter, auto laneCountAttr = pim::getCheckedI32Attr(
loc, rewriter, computeBatchOp, static_cast<uint64_t>(computeBatchOp.getLaneCount()), "pim core_batch lane count");
rewriter.getI32IntegerAttr(computeBatchOp.getLaneCount()), if (failed(laneCountAttr))
ValueRange(batchWeights), return failure();
ValueRange(batchInputs)); auto coreBatchOp =
pim::PimCoreBatchOp::create(rewriter, loc, *laneCountAttr, ValueRange(batchWeights), ValueRange(batchInputs));
coreBatchOp.getProperties().setOperandSegmentSizes( coreBatchOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())}); {static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds)); coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(*coreIds));
SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
returnOperandIndices[resultIndex] = *returnOperandIndex;
}
}
SmallVector<Type> blockArgTypes; SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs; SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) { unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
blockArgTypes.push_back(arg.getType()); blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc()); blockArgLocs.push_back(arg.getLoc());
} }
@@ -102,38 +153,41 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
IRMapping mapper; IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock); rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) { auto oldLaneArg = computeBatchOp.getLaneArgument();
if (!oldLaneArg)
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
if (!oldWeightArg)
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
}
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
if (!oldArg)
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType()); auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
loc, auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), newArg);
outputBuffer.getType(), if (failed(sizeAttr))
outputBuffer, return failure();
newArg, auto copied = pim::PimMemCopyHostToDevOp::create(
rewriter.getI32IntegerAttr(0), rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, newArg, *sizeAttr)
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput(); .getOutput();
mapper.map(oldArg, copied); mapper.map(*oldArg, copied);
} }
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value { SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
if (auto mapped = mapper.lookupOrNull(capturedTensor)) auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
return mapped; Value& hostOutputTensor = hostOutputTensors[resultIndex];
if (hostOutputTensor)
return hostOutputTensor;
auto capturedType = cast<ShapedType>(capturedTensor.getType()); hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType); return hostOutputTensor;
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
}; };
rewriter.setInsertionPointToEnd(newBlock); rewriter.setInsertionPointToEnd(newBlock);
@@ -141,36 +195,40 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa<spatial::SpatYieldOp>(op)) if (isa<spatial::SpatYieldOp>(op))
continue; continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) { if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
pim::PimSendBatchOp::create(rewriter, auto firstOutputArg = computeBatchOp.getOutputArgument(0);
loc, if (!firstOutputArg)
mapper.lookup(sendBatchOp.getInput()), return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())), for (Operation& nestedOp : parallelOp.getRegion().front()) {
sendBatchOp.getTargetCoreIdsAttr()); auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
continue; if (!insertSlice)
} return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) { auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter); if (!outputArg || outputArg.getOwner() != &oldBlock)
continue; return insertSlice.emitOpError("expected compute_batch output block argument destination");
}
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) { unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType()); if (resultIndex >= returnOperandIndices.size())
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType); return insertSlice.emitOpError("result index out of range while lowering host batch output");
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) { Value mappedSource = mapper.lookup(insertSlice.getSource());
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter); Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), mappedSource);
if (failed(sizeAttr))
return failure();
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
hostTarget.getType(),
hostTargetOffset,
zeroOffset,
hostTarget,
mappedSource,
*sizeAttr);
}
continue; continue;
} }
@@ -178,31 +236,37 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) { if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper); Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0); auto clonedTensor = cloned->getResult(0);
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
mapper.map(toTensorOp.getResult(), clonedTensor);
continue;
}
auto clonedType = cast<ShapedType>(clonedTensor.getType()); auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType); auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter, Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
loc, auto sizeAttr = getTensorSizeInBytesAttr(rewriter, coreBatchOp.getOperation(), clonedTensor);
outputBuffer.getType(), if (failed(sizeAttr))
outputBuffer, return failure();
clonedTensor, auto copied =
rewriter.getI32IntegerAttr(0), pim::PimMemCopyHostToDevOp::create(
rewriter.getI32IntegerAttr(0), rewriter, loc, outputBuffer.getType(), zeroOffset, zeroOffset, outputBuffer, clonedTensor, *sizeAttr)
getTensorSizeInBytesAttr(rewriter, clonedTensor)) .getOutput();
.getOutput();
mapper.map(toTensorOp.getResult(), copied); mapper.map(toTensorOp.getResult(), copied);
continue; continue;
} }
} }
for (Value operand : op.getOperands()) { for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand)) if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue; continue;
if (isExplicitDevToHostTargetOperand(&op, operandIndex))
continue;
Operation* definingOp = operand.getDefiningOp(); Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock) if (definingOp && definingOp->getBlock() == &oldBlock)
continue; continue;
materializeCapturedTensor(operand); return computeBatchOp.emitOpError(
"expected external tensor communication to be materialized in Spatial before batch lowering");
} }
Operation* cloned = rewriter.clone(op, mapper); Operation* cloned = rewriter.clone(op, mapper);
@@ -1,10 +0,0 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -3,17 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen) add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim add_pim_library(OMSpatialToPim
Patterns.cpp
SpatialToPimPass.cpp SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp Common.cpp
ComputeLikeRegionUtils.cpp ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
PhaseVerification.cpp
ReturnPathNormalization.cpp ReturnPathNormalization.cpp
TensorPackingPatterns.cpp Patterns/ChannelLowering.cpp
Patterns/GlobalTensorMaterialization.cpp
Patterns/TensorPacking.cpp
Patterns/Transpose.cpp
EXCLUDE_FROM_OM_LIBS EXCLUDE_FROM_OM_LIBS
@@ -21,7 +21,10 @@ add_pim_library(OMSpatialToPim
SpatialToPimIncGen SpatialToPimIncGen
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect MLIRSCFDialect
MLIRSCFUtils
MLIRTransformUtils
MLIRTosaDialect MLIRTosaDialect
OMCompilerOptions OMCompilerOptions
OMPimCommon OMPimCommon
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir

Some files were not shown because too many files have changed in this diff Show More