156 Commits

Author SHA1 Message Date
ilgeco 852bef7605 ReduceMean + resnet
Validate Operations / validate-operations (push) Waiting to run
2026-06-10 14:30:10 +02:00
ilgeco 237654dadf Fix direct import
Validate Operations / validate-operations (push) Waiting to run
2026-06-10 12:14:20 +02:00
ilgeco 6d69600bc1 Yolo Image Validator + new accept rule
Validate Operations / validate-operations (push) Waiting to run
2026-06-10 11:59:43 +02:00
NiccoloN aec80529ca much faster MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 18:22:59 +02:00
ilgeco 8ddbbcecfa Added support for SliceOp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 17:36:51 +02:00
ilgeco 90c4339808 SpatialSubOp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 17:12:16 +02:00
ilgeco 08870de1a6 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-05 16:43:50 +02:00
NiccoloN a34ac223c0 fix remaining failing tests
Validate Operations / validate-operations (push) Has been cancelled
remove unsupported tests
2026-06-05 15:27:11 +02:00
NiccoloN 0fa10b4074 better Conv.cpp and fixed broken conv op validation test
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 13:35:27 +02:00
NiccoloN e166ff7e1d better AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-06-05 11:36:01 +02:00
ilgeco a70a8f77cf Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-05 10:20:09 +02:00
ilgeco 800c0c4316 Python peft and new summary report 2026-06-05 10:20:02 +02:00
NiccoloN 1e9e61f5a9 remove useless MaterializeHostConstantsPass.cpp and fix lowering before instead
Validate Operations / validate-operations (push) Has been cancelled
avoid spammy pim codegen diagnostics
2026-06-05 10:06:28 +02:00
ilgeco 27410207c4 New corner case test
Validate Operations / validate-operations (push) Has been cancelled
2026-06-04 16:00:48 +02:00
NiccoloN cbc9808229 more generalized MaterializeMergeSchedule.cpp for better memory usage after materialization
Validate Operations / validate-operations (push) Has been cancelled
2026-06-04 12:44:57 +02:00
NiccoloN 69021d56aa automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 19:43:56 +02:00
NiccoloN dc5edd032c Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-03 19:40:53 +02:00
NiccoloN e33f517221 faster scheduling: split batches into numCores tasks before scheduling instead of numLanes tasks 2026-06-03 19:40:34 +02:00
ilgeco f94b3d1020 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 18:15:33 +02:00
ilgeco 20cf40c9ba Memory Liveness 2026-06-03 18:15:30 +02:00
NiccoloN 37a59054a5 better loop compaction in MaterializeMergeSchedule.cpp
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 16:01:19 +02:00
ilgeco 2a8faf9c6b Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-06-03 13:49:42 +02:00
ilgeco 01b9d03fc6 Early warning on memory address 2026-06-03 13:49:39 +02:00
NiccoloN 501e6c76f3 better memory report
Validate Operations / validate-operations (push) Has been cancelled
capped vector allocations at u32::MAX in rust simulator
2026-06-03 13:48:42 +02:00
ilgeco 3c2667f11e Fix memory bug
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 12:59:58 +02:00
NiccoloN 0a5e73c3ea better transpose pattern and cleanup
Validate Operations / validate-operations (push) Has been cancelled
2026-06-03 12:26:31 +02:00
NiccoloN 636310d0cb add shared loop creation helpers
Validate Operations / validate-operations (push) Has been cancelled
add shared checked arithmetic helpers
refactor pim passes into Pim/Transforms
more robust memory coalescing pass
2026-06-01 16:49:06 +02:00
NiccoloN 356be6ccc2 uniquify constants produced by affine lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-06-01 10:52:25 +02:00
NiccoloN b678e55d3c compact memory contiguity with for loops
Validate Operations / validate-operations (push) Has been cancelled
2026-05-31 18:47:59 +02:00
NiccoloN ab63498f3f normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled
2026-05-30 16:37:28 +02:00
NiccoloN 7c3943bd06 Merge remote-tracking branch 'origin/refactorone' into refactorone
Validate Operations / validate-operations (push) Has been cancelled
# Conflicts:
#	src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp
2026-05-30 16:12:42 +02:00
NiccoloN c0238c0d06 fix high memory usage caused by MaterializeMergeSchedule.cpp with more robust code 2026-05-30 16:12:06 +02:00
NiccoloN ff36729140 centralize logic for materializing contiguous memory into bufferization
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 16:09:58 +02:00
NiccoloN cf93caecd5 centralize logic for materializing contiguous memory into bufferization
Validate Operations / validate-operations (push) Has been cancelled
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 15:54:24 +02:00
NiccoloN 2d5b03c08f automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 19:21:37 +02:00
NiccoloN a41f694cf0 batched matmul pattern
Validate Operations / validate-operations (push) Has been cancelled
add conv helpers
new validation tests for matmul
2026-05-29 19:09:48 +02:00
NiccoloN 8bb0babf1b finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled
use uniqued constant helpers everywhere
materialize transposed constants directly
2026-05-29 17:05:45 +02:00
ilgeco 819d8af0f7 Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 15:57:13 +02:00
ilgeco 832bd7f1f7 Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 13:23:31 +02:00
ilgeco 82b44a6387 New Onnx test gemm model 2026-05-29 11:41:30 +02:00
ilgeco 7fcc765d6e New Onnx Test model 2026-05-29 11:37:17 +02:00
ilgeco f34698a2b6 Validate new option for compile only
Validate Operations / validate-operations (push) Has been cancelled
2026-05-28 22:59:26 +02:00
ilgeco 1ab489fe0a Dynamic gemm/conv 2026-05-28 18:00:14 +02:00
ilgeco cbf7b235f1 pim-simulator now support usize addresses
Validate Operations / validate-operations (push) Has been cancelled
2026-05-28 17:03:19 +02:00
NiccoloN 00414dd1d9 add verification of communication invariants at the end of spatial
Validate Operations / validate-operations (push) Has been cancelled
remove dead logic
2026-05-27 19:17:48 +02:00
NiccoloN 783dffe553 fix scheduling cost model
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 17:14:19 +02:00
NiccoloN 874a2f53e6 automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:39:56 +02:00
NiccoloN 4bdaa57656 simplify affine maps to constants where possible
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:39:27 +02:00
NiccoloN 1a5d7d2a3f fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:15:10 +02:00
ilgeco 013ae0ac2a Update README and AGENTS
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 15:09:30 +02:00
ilgeco c6b02af7a9 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-27 14:32:51 +02:00
ilgeco d2048bd394 Add to gitignore 2026-05-27 14:32:47 +02:00
NiccoloN 158f0f0c54 update AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 14:32:04 +02:00
NiccoloN 532cac8246 commit AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 14:07:34 +02:00
NiccoloN d609e84054 teh only weight (WIP)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-26 18:42:14 +02:00
NiccoloN addfc8a86e remove other dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 21:22:08 +02:00
NiccoloN 0f240af271 cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 20:58:51 +02:00
ilgeco bdc4ca33f3 No extract no more
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 18:19:43 +02:00
ilgeco b79c333c6c Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-25 15:44:40 +02:00
ilgeco eea9261c7b Bye Bye DCP 2026-05-25 15:44:30 +02:00
NiccoloN e8a08f6dd0 faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 15:24:12 +02:00
NiccoloN 4855a2e105 add verification of static weights in spatial
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 12:00:42 +02:00
NiccoloN 3a7a832198 MaterializeMergeSchedule.cpp fix for yolo11_depth_18 2026-05-24 11:54:00 +02:00
NiccoloN 48ca6bd28d speed fix with a simple cache
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 10:52:28 +02:00
NiccoloN f595cc6ffd fix high memory usage in IR 2026-05-24 10:41:47 +02:00
NiccoloN c734f1b37e better MaterializeMergeSchedule.cpp that emits much more compact IR
Validate Operations / validate-operations (push) Has been cancelled
add support for other constant-time arith ops in codegen
2026-05-24 10:10:24 +02:00
NiccoloN b79ce8eeaa use affine dialect to express simple constant progressions
Validate Operations / validate-operations (push) Has been cancelled
run dce at the end of MaterializeMergeSchedule to get rid of unused constants
2026-05-23 14:25:34 +02:00
NiccoloN 76a37e198f better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops
Validate Operations / validate-operations (push) Has been cancelled
2026-05-23 11:17:36 +02:00
NiccoloN 7f3c7464b4 update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 22:16:19 +02:00
NiccoloN c77ffa9c56 better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
2026-05-22 21:52:28 +02:00
NiccoloN 495186503c fix cmake magic once again 2026-05-22 19:21:56 +02:00
NiccoloN 2c1da813b5 fix much stuff 2026-05-22 18:53:38 +02:00
NiccoloN 8337a11ce9 automatic code reformat 2026-05-22 15:23:48 +02:00
ilgeco d136136d22 Fix add of input in random order for compute_batch
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 15:21:02 +02:00
NiccoloN 074eb183c7 saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 07:27:54 +02:00
NiccoloN 43ed3914b8 better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 06:56:39 +02:00
ilgeco 6aaf1c0870 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:44:19 +02:00
ilgeco fe35b3ed43 Equivalent Class but broken 2026-05-21 14:43:59 +02:00
NiccoloN 90a9339686 better cmake to keep IDEs analyses happy
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:13:54 +02:00
NiccoloN a50e77ff38 refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-20 19:06:41 +02:00
NiccoloN f56c4159b5 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-19 15:01:26 +02:00
ilgeco 5637c861b4 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-19 15:00:11 +02:00
ilgeco 94157a8404 Very big timeout 2026-05-19 14:53:34 +02:00
ilgeco 68a3521978 Perft topological fix 2026-05-19 14:52:54 +02:00
NiccoloN a103ba328b remove dead logic 2026-05-19 12:23:01 +02:00
NiccoloN e263e05f56 remove dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 18:32:40 +02:00
ilgeco 34c29fdec4 Materialize modification
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:22:13 +02:00
ilgeco aa088e2ba5 Verify fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 17:20:40 +02:00
NiccoloN 2836e759ab remove useless file
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:51:03 +02:00
NiccoloN 8071ebab0b faster refactored merge pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-18 14:50:19 +02:00
NiccoloN f1602c0550 add peft scheduling
Validate Operations / validate-operations (push) Has been cancelled
better deadlock report by pim simulator
2026-05-18 12:09:27 +02:00
NiccoloN de0a2f4561 remove useless guard in gemm lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:22:13 +02:00
NiccoloN 1c4a5bde76 compact softmax op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 18:14:59 +02:00
NiccoloN 78242e2887 compact resize op lowering
Validate Operations / validate-operations (push) Has been cancelled
2026-05-15 17:36:12 +02:00
NiccoloN fe244d5aa1 new ops tests for matmul, grouped conv, concat and reshape
Validate Operations / validate-operations (push) Has been cancelled
related fixes
2026-05-14 15:54:06 +02:00
NiccoloN d09e76c8f9 fix matmul rewriting/lowering
Validate Operations / validate-operations (push) Has been cancelled
fix reshape lowering
add support for grouped-convolution lowering
quieter verifier with capped error messages
2026-05-14 14:09:30 +02:00
NiccoloN c5e608fa5b replace greedy pattern rewrites with partial conversions
Validate Operations / validate-operations (push) Has been cancelled
better failure messages
2026-05-14 11:48:16 +02:00
ilgeco 43f3ccdd21 new yolo nodes with 100% more statics
Validate Operations / validate-operations (push) Has been cancelled
2026-05-14 10:47:31 +02:00
NiccoloN 8d95c604a6 automatic code formatting
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 21:51:19 +02:00
NiccoloN 55eda487dc use seed in validate.py for deterministic tests 2026-05-13 21:49:36 +02:00
NiccoloN 061139aefb fix wrong send/receive reordering in post dcp merge instructions compaction 2026-05-13 21:48:49 +02:00
NiccoloN ea61540e08 fix failing validations after last commit
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 17:46:19 +02:00
NiccoloN 324178cba8 fix instructions explosion in pim host constant folding pass
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 17:31:05 +02:00
NiccoloN e71ba07cd5 fix pim-simulator stale tests
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:59:53 +02:00
NiccoloN 64a3805619 fix pim-simulator stale tests
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:59:43 +02:00
NiccoloN 9f9e7c0892 Merge remote-tracking branch 'origin/main'
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 16:38:33 +02:00
NiccoloN 03eab42971 remove host core generation
strip config.json emitted by raptor
add actual pimsim-nn configs in validation pimsim-configs
2026-05-13 16:31:01 +02:00
ilgeco c15aba5d96 pim-simulator removed useless comment
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 15:05:17 +02:00
ilgeco 4821e8a55e Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-13 15:03:20 +02:00
ilgeco 88bb223bb1 Fix multiple address bug 2026-05-13 14:15:41 +02:00
NiccoloN 623ee62a04 point pimsim-nn submodule to the HEAPLab fork
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 14:10:27 +02:00
ilgeco ad56888b0b Broken pim-sim commit 2026-05-13 13:32:09 +02:00
NiccoloN f993840641 update pimsim-nn submodule
Validate Operations / validate-operations (push) Has been cancelled
2026-05-13 11:39:47 +02:00
NiccoloN 0c7db55a24 binary pim code for reduced memory usage
Validate Operations / validate-operations (push) Has been cancelled
fast pim code emission
2026-05-13 11:15:54 +02:00
NiccoloN 41de3cb150 add memory coalescing pass
Validate Operations / validate-operations (push) Has been cancelled
better reports
refactor for more code-reuse and patter usage
fixes
2026-05-12 18:17:00 +02:00
NiccoloN 4f3570520c add pim.vmm verifier and fix vmm lowering
reuse code for subviews
2026-05-12 15:13:50 +02:00
NiccoloN 628dc630a4 compact syntax for spatial tensor ops
Validate Operations / validate-operations (push) Has been cancelled
better IR compaction after dcp merge
remove pim.mvm op
better memory report
2026-05-12 13:35:25 +02:00
NiccoloN 80a7298552 fix pool lowering
Validate Operations / validate-operations (push) Has been cancelled
better reports (dcp merge and memory)
2026-05-12 12:32:23 +02:00
ilgeco 8ad504fcdf Yolo splitted at conv boundary
Validate Operations / validate-operations (push) Has been cancelled
2026-05-12 11:33:15 +02:00
ilgeco e6f442c5d2 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-12 10:43:01 +02:00
ilgeco f6b97b3813 Fix report memory 2026-05-12 10:42:38 +02:00
ilgeco 26317ea7d0 Shorter Memory Reporty 2026-05-12 10:38:35 +02:00
NiccoloN 909c4acfdd huge refactor for high RewritePatterns usage and less ad-hoc cpp code
Validate Operations / validate-operations (push) Has been cancelled
remove Spatial many ops in favor of tensor ops like in pim
2026-05-12 10:35:44 +02:00
ilgeco feaff820e1 pim-sim TraceTime + faer
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 18:19:30 +02:00
NiccoloN 1e279ae9bb minor fix
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 16:01:42 +02:00
NiccoloN 57f0cca8c0 remove duplicated code
Validate Operations / validate-operations (push) Has been cancelled
quieter validation scripts (with optional verbose flag)
2026-05-11 15:52:26 +02:00
NiccoloN 5ff364027b big cleanup: remove remaining pim many operations, simplify bufferization logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-11 14:38:13 +02:00
NiccoloN b1272d2283 fast pim bufferization using tensors
Validate Operations / validate-operations (push) Successful in 24m29s
2026-05-08 14:21:45 +02:00
NiccoloN 58e6587697 Merge remote-tracking branch 'origin/main' 2026-05-08 13:12:47 +02:00
NiccoloN f6c8cc4aa5 sightly better bufferization
minor fixes
2026-05-07 17:53:47 +02:00
ilgeco 566630b99a Removed SpatNopPattern
Validate Operations / validate-operations (push) Successful in 22m36s
2026-05-07 17:03:35 +02:00
ilgeco 74931ad75b Single Concat Fix 2026-05-07 16:47:01 +02:00
NiccoloN f2fe147961 compact pim IR
Validate Operations / validate-operations (push) Successful in 22m15s
2026-05-06 17:16:51 +02:00
ilgeco 7bb58e80de Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into main
Validate Operations / validate-operations (push) Successful in 24m25s
2026-05-06 12:25:29 +02:00
NiccoloN b2dc9c38b6 better spatial IR compaction with better custom syntax, scf.for and
Validate Operations / validate-operations (push) Has been cancelled
spat.map
2026-05-06 12:21:58 +02:00
ilgeco 3cb6a1abc5 Memory report 2026-05-06 10:47:04 +02:00
NiccoloN 285773fa55 rework actually broken dcp merge + compute re-batching (still to refine) 2026-05-04 19:30:40 +02:00
NiccoloN bdacb9871d fix dcp merge bug
Validate Operations / validate-operations (push) Failing after 15m54s
2026-05-04 15:58:14 +02:00
NiccoloN 5b9bb0c191 refactor spatial ops
Validate Operations / validate-operations (push) Successful in 24m55s
2026-05-04 14:19:30 +02:00
NiccoloN f789954ad7 Refactor ONNXToSpatial Common and diagnostics 2026-05-04 13:42:43 +02:00
ilgeco b6ba1e4fea Fix DCPTest using old constructor
Validate Operations / validate-operations (push) Successful in 24m15s
2026-05-04 10:58:51 +02:00
NiccoloN 717ad160cd Refactor PIM/Common (splitting in files, adding helpers, adding brief
Validate Operations / validate-operations (push) Failing after 18m36s
docs)
2026-05-04 09:20:43 +02:00
NiccoloN 905fa9f9a7 Merge remote changes
Validate Operations / validate-operations (push) Failing after 18m42s
2026-05-03 23:09:32 +02:00
NiccoloN 62b0a6e19d merge remote changes 2026-05-03 22:30:46 +02:00
NiccoloN b605585b1f compact spatial IR through different new operations and dedicated syntax
fast spatial node merging with batch operations
2026-05-03 14:14:14 +02:00
ilgeco 08b0fcd850 Parallel bufferization
Validate Operations / validate-operations (push) Successful in 21m49s
2026-04-30 11:48:17 +02:00
ilgeco 9dccc2c701 Translate global constant to symble 2026-04-28 12:42:01 +02:00
ilgeco 5c839e62c1 Func Input converted to symbol 2026-04-27 13:48:03 +02:00
NiccoloN 15e8edb9c4 better spat computes merging
Validate Operations / validate-operations (push) Successful in 21m14s
2026-04-25 19:24:09 +02:00
ilgeco 951baca106 Merge Node update fix comparison bug
Validate Operations / validate-operations (push) Successful in 20m21s
2026-04-23 19:52:16 +02:00
ilgeco fc5bccb487 Merge Node update status file
Validate Operations / validate-operations (push) Has started running
2026-04-23 19:42:56 +02:00
ilgeco 49dea15b95 DCP Merge status
Validate Operations / validate-operations (push) Successful in 22m29s
2026-04-23 18:40:33 +02:00
NiccoloN 5545b0f672 fix MatMul pattern non-contiguous extract_slices
Validate Operations / validate-operations (push) Successful in 22m31s
2026-04-23 14:44:30 +02:00
NiccoloN cff929a083 fix sigmoid implementation stability in pim-simulator
Validate Operations / validate-operations (push) Successful in 23m4s
2026-04-23 10:34:29 +02:00
NiccoloN 89b3501aa8 fix weightAlways attribute in spatial 2026-04-23 10:04:47 +02:00
NiccoloN 412ca957f6 multiple-output spat computes
Validate Operations / validate-operations (push) Successful in 22m38s
2026-04-23 09:28:57 +02:00
419 changed files with 34322 additions and 9510 deletions
+11 -2
View File
@@ -1,5 +1,14 @@
.zed
.idea
**/.vscode
.claude
AGENTS.md
build
.codex
CMakeUserPresets.json
build_*
compile.sh
pimcomp_utils/*
**/__*
+1 -1
View File
@@ -3,4 +3,4 @@
url = https://github.com/onnx/onnx-mlir.git
[submodule "backend-simulators/pim/pimsim-nn"]
path = backend-simulators/pim/pimsim-nn
url = https://github.com/wangxy-2000/pimsim-nn.git
url = https://github.com/HEAPLab/pimsim-nn.git
+210
View File
@@ -0,0 +1,210 @@
* Always read the full README.md before doing anything
* Build commands:
* `cmake --build ./build_release`
* `cmake --build ./build_debug`
* Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache
* Always try the release build first before building with the debug version
* Use the debug build only when it is useful to obtain a clear stack trace with symbols, inspect names, place breakpoints, or test a small case interactively
* The debug build is very slow, so use it only on small fast tests such as operation validations, not on network validations
# Core engineering philosophy
* Clean architecture matters as much as making the immediate test pass
* Prefer fixes that preserve clear ownership boundaries, explicit invariants, and simple dataflow
* Do not stack compensating fixes on top of earlier mistakes. If the current approach is becoming messy, stop and explain why
* A correct fix should usually make the responsible producer, resolver, verifier, or lowering own the behavior directly
* Avoid late repair passes, defensive cleanup, or broad rewrites when a cleaner owner-side fix is possible
* Do not hide an upstream modeling bug by normalizing it later in the pipeline. Fix the producer when the producer owns the invariant
* Prefer patterns/rewrites for local IR canonicalization. Use module walks only when pass-level structural analysis genuinely requires them
* Prefer compact, structured designs over long case-by-case implementations
# Think before coding
* State assumptions explicitly before implementing when they affect the design
* If multiple interpretations exist, present them instead of silently choosing one
* If a simpler approach exists, say so and prefer it unless there is a clear reason not to
* If something is unclear, stop, name what is confusing, and ask
* If the requested or obvious approach would make the architecture worse, push back and propose a cleaner alternative
# Code changes
* Keep changes minimal and localized to the relevant parts of the code
* Preserve the existing naming conventions and coding style used in the surrounding code
* Keep code easy to read, well organized, and suitable for future extensibility
* A function must not exceed roughly 200/250 lines. If a change pushes a function beyond that, extract focused helpers
* Prefer clear naming and structure over comments. Add comments only when they materially improve clarity
* Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change
* Avoid duplicate ad-hoc logic. If the same concept appears in multiple places, consider whether it deserves a shared helper/API
* When adding a helper or API, ask:
* Could this be useful to another component now
* Is another component already implementing the same idea differently
* Is this likely to be needed by a future adjacent component
* What is the narrowest useful abstraction
* What is the correct ownership level for this API
* If a shared API is justified, place it at the lowest clean layer that can be used by all relevant consumers without creating dependency cycles or leaking policy across layers
* If an existing component should use a newly introduced shared API, refactor that component in the same patch when doing so is directly related and reduces duplication
* Do not create broad frameworks just because a helper might someday be useful. Shared APIs should encode a real reusable concept, not speculative generality
* If the reusable abstraction is plausible but not clearly needed yet, keep the code local and mention the possible future extraction separately
# Avoid case-listing designs
* Avoid solving problems with large chains of `if`/`else`, switches, or repeated special cases that enumerate every possible situation
* Long case listings tend to overfit the current tests, grow the codebase, and hide the underlying abstraction
* When you see a growing list of special cases, stop and look for the shared concept, data model, interface, or normalization step that would make the cases collapse
* Prefer table-driven logic, traits/interfaces, small reusable predicates, structured dispatch, or producer-side normalization when they express the invariant more directly
* A few explicit cases are fine when the domain is genuinely small and closed
* If the list is likely to grow, refactor toward a cleaner and more compact design instead of adding another branch
* When keeping a case list is the pragmatic choice, explain why the domain is closed or why a broader abstraction would be premature
# Ownership and invariants
Before implementing, identify the owner of the behavior:
* A producer should emit IR/data that satisfies the contract of the next stage
* A lowering should make representation changes explicit and semantically correct
* A resolver should resolve existing structure without silently changing semantics
* A verifier should reject invalid states with bounded, actionable diagnostics
* Codegen should assume verified invariants and fail clearly if they are violated
When fixing a bug:
* State the invariant that was violated
* State which component should own that invariant
* Fix that component directly
* Avoid fixes that merely mask the violation later in the pipeline
* Add or preserve verification if the invariant is important enough to regress
# Refactor and API policy
You may propose or implement a refactor when:
* the local fix would duplicate logic
* the local fix would violate a layer boundary
* the bug exists because responsibility is assigned to the wrong component
* multiple components already implement ad-hoc variants of the same concept
* a shared helper/API would make the code smaller, clearer, and easier to maintain
* existing callers can be migrated cleanly without broad churn
* the current implementation is turning into a long list of special cases instead of a structured solution
When proposing or implementing a refactor:
* Explain what responsibility is being moved or shared
* Justify why the new location is the right ownership level
* Keep the API narrow and named after the concept or invariant it represents
* Migrate directly related existing users when that improves compactness and consistency
* Separate changes required for correctness from optional cleanup
* Avoid unrelated renames, formatting changes, or module moves
* Do not expand a justified refactor beyond directly related callers
Do not refactor when:
* the issue is truly local and a local fix is clearer
* the abstraction would have only one user and no clear adjacent use
* the abstraction would mix policies from different layers
* the refactor would affect unrelated behavior
* the refactor is mainly aesthetic
# Working style
* Infer style and conventions from the existing code before introducing new patterns
* When several implementation options are possible, prefer the simplest one that fits the current architecture and minimizes churn
* Push back when the requested or obvious fix would make the architecture worse
* If a cleaner fix requires a small refactor or shared helper/API, propose it explicitly and justify it
* Avoid broad refactors unless explicitly requested or clearly necessary for correctness and maintainability
* When tests fail, bucket failures by likely root cause and separate patch-related failures from pre-existing or out-of-scope failures
# Simplicity first
* Minimum code that solves the problem cleanly. Nothing speculative
* No features beyond what was asked
* No error handling for impossible scenarios
* If you write 200 lines and it could be 50, rewrite it
* Ask: “Would a senior engineer say this is overcomplicated?” If yes, simplify
* Prefer direct, explicit code over generic machinery unless the generic machinery clearly reduces duplication and preserves boundaries
# Fallbacks and defaults
* Avoid silent fallback behavior when the semantic category is unknown
* Do not treat “unknown” as “safe” unless the codebase already defines that convention
* If a value cannot be classified, either preserve the existing behavior deliberately or fail with a clear diagnostic
* When adding a fallback, state why it is semantically valid and what invariant makes it safe
# Surgical changes
* Touch only what you must
* Clean up only the mess introduced by your own change
* Do not “improve” adjacent code, comments, or formatting
* Match existing style, even if you would personally do it differently
* If you notice unrelated dead code, bad abstractions, or fragile design, mention it separately. Do not delete or rewrite it unless asked
* When your changes create orphans, remove imports, variables, functions, or files made unused by your change
* Every changed line should trace directly to the requested fix, a required cleanup, or a justified reuse/refactor decision
# Diagnostics and verification
* Use existing bounded diagnostic mechanisms for pass-level verification or codegen failures
* Do not emit unbounded repeated diagnostics from loops or parallel workers
* Diagnostics should identify the violated invariant and the relevant value/op when useful
* Verifiers should reject invalid states, not repair them
* Codegen should not compensate for invalid IR/data unless codegen is the owner of that invariant
* Do not make failing tests pass by weakening verifiers, assertions, or diagnostics unless the check itself is proven wrong
* If a check is too strict, explain the valid case it rejects and update the invariant accordingly
* Prefer fixing invalid IR/data producers over relaxing consumers
* If adding diagnostics only for debugging, remove them or cap them before finalizing
# Temporary debugging code
* Temporary diagnostics, dumps, assertions, and debug-only helpers must be removed or intentionally converted into bounded permanent diagnostics before finalizing
* If debug instrumentation remains, explain why it is useful as permanent infrastructure
* Do not leave noisy validation output behind
# Performance awareness
* Avoid algorithmic regressions in compiler passes, especially repeated full-module walks, repeated expensive analyses, or per-op recomputation inside nested loops
* If a change adds a walk, cache, analysis, or structural traversal, justify why it is needed
* For hot paths, prefer preserving existing asymptotic behavior unless a better structure is part of the requested change
* If performance may change, mention the expected impact and suggest a targeted timing check
# Goal-driven execution
For multi-step tasks, state a brief plan:
1. [Step] → verify: [check]
2. [Step] → verify: [check]
3. [Step] → verify: [check]
Define success criteria before implementing:
* For bug fixes, success means reproducing or identifying the failure, fixing the responsible owner, and verifying the targeted case
* For refactors, success means preserving behavior while making ownership, reuse, or structure cleaner
* For validation changes, success means checking both valid and invalid cases when applicable
Transform tasks into verifiable goals:
* “Fix the bug” → identify the invariant, reproduce the failure, fix the owner, verify the targeted case
* “Add validation” → write or identify tests for invalid inputs, then make them pass/fail as expected
* “Refactor X” → preserve behavior before and after, then run relevant tests
# Final self-review
Before reporting completion, check:
* Did I fix the owner of the invariant rather than masking the issue downstream
* Did I avoid broad case lists and ad-hoc special handling
* Did I introduce a helper/API only at the right ownership level
* Did I migrate directly related duplicate logic when doing so improves compactness
* Did I avoid weakening verifiers or assertions unnecessarily
* Did I remove temporary debugging code or make it bounded and intentional
* Did I avoid unrelated formatting, renames, or cleanup
* Did I consider performance impact for added walks, analyses, caches, or repeated computations
* Did I run the required build/test commands
* Did I clearly report remaining failures or risks
When reporting back:
* Say what changed
* Say what was verified
* Say what remains
* When showing code in chat, make it easy to copy-paste into the codebase
* Keep outputs focused on the changed parts
* List bad practices, fragile assumptions, or cleaner alternatives separately
* If a change is intentionally pragmatic rather than architecturally ideal, say so and explain the tradeoff
+92 -24
View File
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir
function(raptor_ensure_symlink link_path target_path)
get_filename_component(link_parent "${link_path}" DIRECTORY)
# Materialize a CMake shim directory
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
if(NOT EXISTS "${link_parent}")
message(FATAL_ERROR "Directory not found: ${link_parent}")
endif()
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
SYMBOLIC
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
message(FATAL_ERROR
"External CMake source directory not found or missing CMakeLists.txt:\n"
" ${real_external_source_dir}"
)
endif()
endif ()
if (IS_SYMLINK "${shim_dir}")
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
file(REMOVE "${shim_dir}")
endif ()
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
endif ()
file(MAKE_DIRECTORY "${shim_dir}")
set(shim_file "${shim_dir}/CMakeLists.txt")
set(shim_contents
"get_filename_component(raptor_external_source_dir
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
REALPATH
)
add_subdirectory(
\"\${raptor_external_source_dir}\"
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
)
if (DEFINED PIM_ENABLED)
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
endif ()
"
)
if (EXISTS "${shim_file}")
file(READ "${shim_file}" old_contents)
else ()
set(old_contents "")
endif ()
if (NOT old_contents STREQUAL shim_contents)
file(WRITE "${shim_file}" "${shim_contents}")
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
else ()
message(STATUS "CMake shim already up to date for ${description}")
endif ()
# Mirror the external tree's first-level entries into the shim directory
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
foreach (child IN LISTS children)
if (child STREQUAL "CMakeLists.txt")
continue()
endif ()
set(real_child "${real_external_source_dir}/${child}")
set(shim_child "${shim_dir}/${child}")
if (IS_SYMLINK "${shim_child}")
file(READ_SYMLINK "${shim_child}" existing_link_target)
if (existing_link_target STREQUAL real_child)
continue()
endif ()
file(REMOVE_RECURSE "${shim_child}")
elseif (EXISTS "${shim_child}")
# Do not delete real files/directories. This protects the generated shim.
continue()
endif ()
file(CREATE_LINK
"${real_child}"
"${shim_child}"
SYMBOLIC
)
endforeach ()
endfunction()
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
"PIM accelerator"
)
raptor_ensure_symlink(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
"PIM accelerator tests"
)
# Patch onnx-mlir sources for PIM accelerator support.
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
# Already applied replacement text is present
string(FIND "${contents}" "${replacement}" already_applied_pos)
if(NOT already_applied_pos EQUAL -1)
if (NOT already_applied_pos EQUAL -1)
message(STATUS "Patch already applied: ${description}")
return()
endif()
endif ()
# Anchor must exist for the patch to be applicable
string(FIND "${contents}" "${anchor}" anchor_pos)
if(anchor_pos EQUAL -1)
if (anchor_pos EQUAL -1)
message(FATAL_ERROR
"Patch anchor not found onnx-mlir may have changed.\n"
" Patch : ${description}\n"
" File : ${file_path}\n"
" Anchor: ${anchor}"
)
endif()
endif ()
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
file(WRITE "${file_path}" "${patched}")
+296 -40
View File
@@ -1,65 +1,321 @@
# Raptor
Raptor is a domain-specific MLIR compiler for neural networks in ONNX format,
targeting in-memory computing / processing-in-memory (PIM) architectures. It
extends ONNX-MLIR with a PIM accelerator and progressively lowers ONNX-MLIR
through custom MLIR dialects to simulator artifacts.
The current target is the PIM simulator stack under `backend-simulators/pim`.
Raptor emits binary per-core `.pim` instruction files by default, plus
`memory.bin`, `config.json`, and weight binaries. It can also emit per-core JSON
instruction files with `--pim-emit-json`.
## Overview
PIM architectures perform most computation directly in memory. The supported
target models a chip with:
- shared host memory,
- multiple PIM cores,
- ReRAM crossbars for vector-matrix / matrix-vector work,
- explicit communication between cores,
- no hardware branch or loop support in emitted simulator code.
Because repeated work such as convolutions is eventually made explicit, emitted
instruction counts can grow quickly. Most compiler work therefore focuses on
lowering, scheduling, memory layout, and code-generation optimizations.
### Targets and simulators
- `backend-simulators/pim/pim-simulator` is the in-tree Rust functional
simulator used by validation. It reads Raptor's `pim/` artifact directory and
compares simulator output against native ONNX-MLIR execution.
- `backend-simulators/pim/pimsim-nn` is the performance simulator submodule.
The helper scripts in `pimcomp_utils/` are for comparison with PIMCOMP-NN and
contain local paths; treat them as local utilities, not portable workflows.
## Compilation pipeline
The PIM sources live under `src/PIM` and tests under `test/PIM`. CMake exposes
them to ONNX-MLIR through generated shim directories under
`onnx-mlir/src/Accelerators/PIM` and `onnx-mlir/test/accelerators/PIM`.
High-level lowering flow:
```
ONNX-MLIR -> Spatial -> Pim (tensor) -> Pim (bufferized) -> PIM artifacts
```
1. **ONNX -> Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers supported ONNX ops into the `spat` dialect
(`src/PIM/Dialect/Spatial`). Conversion patterns are split by op family under
`Patterns/{Math,NN,Tensor}` and currently cover Conv, Gemm, MatMul,
elementwise Add/Mul/Div, ReduceMean, pooling, Relu, Sigmoid, Softmax,
Concat, Gather, Reshape, Resize, and Split.
2. **Merge compute nodes**
(`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
Builds a compute graph, schedules it with the PEFT scheduler, and materializes
the merge schedule into Spatial IR. Supporting scheduling code lives under
`MergeComputeNodes/Scheduling`.
3. **Spatial -> Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial operations to the `pim` dialect (`src/PIM/Dialect/Pim`),
including `pim.core`, `pim.core_batch`, communication, tensor packing, global
tensor materialization, and return-path normalization.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using MLIR's
bufferization interfaces.
5. **Static memory coalescing**
(`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
Reuses compatible local memref allocations inside PIM cores before codegen.
6. **PIM code generation** (`src/PIM/Pass/PimCodegen` and
`src/PIM/Compiler`).
Folds host constants, materializes remaining host constants, verifies PIM IR,
emits `.pim` core files, writes weights, and writes `memory.bin` /
`config.json`.
Supporting pieces:
- `src/PIM/Common` - shared IR, filesystem, diagnostics, reports, and utility
helpers.
- `src/PIM/Compiler` - PIM compiler options, memory/address planning, binary
instruction format, artifact writing, weight emission, and codegen entry
points.
- `src/PIM/Conversion/SpatialToGraphviz` - optional Spatial graphviz conversion
pass.
- `src/PIM/Pass` - pass registration and auxiliary passes.
- `src/PIM/PimAccelerator.{cpp,hpp}` - ONNX-MLIR accelerator entry point.
## Key compiler options
Pass these to `onnx-mlir` when compiling for PIM:
- `--maccel=PIM` - select the PIM accelerator.
- `--EmitSpatial`, `--EmitPim`, `--EmitPimBufferized`,
`--EmitPimCodegen` - stop the PIM pipeline at the requested stage. The PIM
default is `--EmitPimCodegen`.
- `--core-count=<N>` - required positive core count for PIM compilation.
- `--crossbar-size=<N>` - crossbar width/height. Default in code is `2`.
- `--crossbar-count=<N>` - crossbars per core. Default in code is `256`.
- `--pim-merge-scheduler=peft` - merge scheduler. `peft` is the only accepted
value in the current code.
- `--pim-only-codegen` - assume input is already bufferized PIM IR and only run
the codegen tail.
- `--pim-emit-json` - also emit `core_*.json` instruction files alongside
`core_*.pim`.
- `--use-experimental-conv-impl` - use the alternate convolution lowering.
- `--ignore-concat-error` - soft-fail a ConcatOp corner case.
Example:
```bash
./build_release/Release/bin/onnx-mlir model.onnx -o /tmp/raptor/model \
--maccel=PIM --EmitPimCodegen \
--crossbar-size=2048 --crossbar-count=256 --core-count=1000
```
This writes PIM artifacts under `/tmp/raptor/pim/`.
## Validation
Functional validation lives in `validation/`. It compiles ONNX models, builds a
native ONNX-MLIR reference runner, generates random inputs, runs Raptor, runs
the Rust PIM simulator, and compares outputs.
Python dependencies used by the validation scripts are `numpy`, `onnx`, and
`colorama`. The simulator requires the Rust toolchain.
Per-operation validation from the repository root:
```bash
python3 validation/validate.py \
--raptor-path build_release/Release/bin/onnx-mlir \
--onnx-include-dir onnx-mlir/include \
--core-count 1000
```
Validate one network or a subset by pointing `--operations-dir` at any directory
containing `.onnx` files:
```bash
python3 validation/validate.py \
--raptor-path build_release/Release/bin/onnx-mlir \
--onnx-include-dir onnx-mlir/include \
--operations-dir validation/networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Useful validation options:
- `--simulator-dir <path>` - override the auto-detected
`backend-simulators/pim/pim-simulator` path.
- `--threshold <float>` - maximum allowed per-element output difference.
- `--seed <int>` - RNG seed for generated inputs.
- `--command-timeout-seconds <float>` - timeout for compiler, runner, and
simulator subprocesses.
- `--verbose` - print subprocess logs and average PIM pass timings.
- `--clean` - remove generated validation artifacts and exit.
Each validation run writes artifacts in the model workspace, for example under
`validation/operations/gemm/small/`:
- `inputs/` - generated input CSV files.
- `outputs/` - native ONNX-MLIR reference outputs.
- `raptor/` - compiler artifacts, including `*.onnx.mlir`, dialect dumps under
`dialects/`, reports under `reports/`, and final PIM artifacts under `pim/`.
- `runner/` - generated reference runner source, build tree, and shared library.
- `simulation/out.bin` - raw simulator output used for comparison.
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
`pim2_coalesced.mlir`, and `pim3_folded.mlir` when an output directory is
available.
To rerun the simulator manually with tracing after validation has produced a
`raptor/pim/` directory:
```bash
cd backend-simulators/pim/pim-simulator
cargo run --no-default-features --features tracing --release \
--package pim-simulator --bin pim-simulator -- \
-f /path/to/workspace/raptor/pim \
-o /path/to/workspace/simulation/out.bin \
-d <addr0>,<size0>,<addr1>,<size1>,...
```
With `--features tracing`, the simulator writes per-core traces as
`TraceCore0`, `TraceCore1`, ... next to `out.bin`. The validator normally
computes the `-d` ranges from `raptor/pim/config.json` and model output shapes.
Available validation networks under `validation/networks/`: `vgg16`,
`yolo11n`, `yolo11nv2`.
Available operation suites under `validation/operations/`: `add`, `concat`,
`conv`, `div`, `gather`, `gemm`, `gemv`, `matmul`, `mul`, `pool`,
`reduce_mean`, `relu`, `reshape`, `resize`, `sigmoid`, `softmax`, `split`.
Generated operation tests can be regenerated with:
```bash
python3 validation/operations/gen_tests.py
```
## Build
Initialize submodules first:
```bash
git submodule update --init --recursive
```
The project follows ONNX-MLIR's build requirements. The CI workflow documents
the currently used versions and setup:
- CMake 4.3.0 in CI,
- LLVM/MLIR checked out under `onnx-mlir/llvm-project`,
- Protobuf `v34.0`,
- Rust stable for `pim-simulator`,
- Python packages `numpy`, `onnx`, `colorama` for validation.
### Protobuf
Use the following commands to install protobuf:
```
Install Protobuf if your system does not already provide a compatible version:
```bash
git clone --depth 1 --branch v34.0 https://github.com/protocolbuffers/protobuf
cd protobuf
mkdir build
cd build
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
ninja
sudo ninja install
cmake -S protobuf -B protobuf/build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-Dprotobuf_BUILD_TESTS=OFF
cmake --build protobuf/build
sudo cmake --install protobuf/build
```
You can now remove the protobuf repo directory with:
```
cd ../..
You can then remove the temporary checkout:
```bash
rm -rf protobuf
```
### Mlir
### MLIR
Follow the first part of instructions [here](onnx-mlir/docs/BuildOnLinuxOSX.md) to build mlir.
Follow the ONNX-MLIR instructions in
`onnx-mlir/docs/BuildOnLinuxOSX.md` to build LLVM/MLIR. The local Raptor build
expects `MLIR_DIR` to point at the MLIR CMake package, for example:
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
Moreover, if compiling with build type debug, it is also suggested to use
mold as linker (you will need to install it if you don't have it already)
to reduce memory usage during linking. You can use it by setting the options:
```
-DLLVM_USE_LINKER=mold
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
```
If your LLVM build directory is named `build` instead of `build_release`, adjust
the path accordingly.
### Raptor
Use the following commands to build Raptor.
Configure a release build:
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
setting the options:
```
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
```
```
git submodule update --init --recursive
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
mkdir build && cd build
cmake .. -G Ninja \
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
cmake -S . -B build_release -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
cmake --build .
```
If the build fails because of protobuf missing uint definitions,
just patch the problematic files by adding ```#include <cstdint>``` to their includes.
Configure a debug build similarly:
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_debug/lib/cmake/mlir
cmake -S . -B build_debug -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
```
For debug development, using `mold` can reduce link time and memory use:
```bash
cmake -S . -B build_debug -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR} \
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
```
Build the compiler with CMake:
```bash
cmake --build ./build_release
cmake --build ./build_debug
```
Do not invoke `ninja` directly for this project; use `cmake --build` so CMake's
configuration and generated shims stay consistent.
If a build fails because Protobuf headers are missing fixed-width integer
definitions, patch the affected Protobuf-generated files by adding
`#include <cstdint>`.
## Tests
The Rust simulator has its own tests:
```bash
cd backend-simulators/pim/pim-simulator
cargo test
```
## Repository Layout
- `src/PIM/` - PIM accelerator implementation.
- `test/PIM/` - PIM C++ unit tests.
- `validation/` - functional validation scripts, ONNX operation tests, network
slices, and pimsim config generation.
- `backend-simulators/pim/pim-simulator/` - in-tree Rust functional simulator.
- `backend-simulators/pim/pimsim-nn/` - performance simulator submodule.
- `pimcomp_utils/` - local comparison helpers for PIMCOMP-NN.
- `.github/actions/` and `.github/workflows/validate_operations.yml` - CI setup
for MLIR/Protobuf caching, building Raptor, and validation.
File diff suppressed because it is too large Load Diff
@@ -1,4 +1,3 @@
[package]
name = "pim-simulator"
version = "0.1.0"
@@ -13,8 +12,9 @@ name = "pimcore"
path = "src/lib/pimcore.rs"
[features]
default = ["tracing"]
default = []
tracing = []
profile_time = ["dep:plotly", "dep:comfy-table", "dep:statrs"]
@@ -27,3 +27,10 @@ hex = "0"
paste = "1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
statrs = {version="0.16", optional=true}
comfy-table = {version="7.1", optional=true}
plotly = {version="0.8", optional=true}
rayon = "1.12.0"
faer = "0.24.0"
faer-traits = "0.24.0"
mimalloc = "0.1.50"
@@ -1,14 +1,20 @@
use mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
use anyhow::{Context, Result, bail};
use clap::Parser;
use glob::glob;
use pimcore::binary_to_instruction::binary_to_executor;
use pimcore::cpu::crossbar::Crossbar;
use pimcore::json_to_instruction::json_to_executor;
use pimcore::memory_manager::CoreMemory;
use pimcore::tracing::TRACER;
use serde_json::Value;
use std::collections::HashMap;
use std::fs::{self, read_link};
use std::io::Write;
use std::fs::{self, File, read_link};
use std::io::{BufReader, Write};
use std::path::PathBuf;
/// Program to simulate core execution configuration
@@ -37,25 +43,31 @@ struct Args {
/// Comma separated list of (address,size) for memory output dump
#[arg(short, long, value_delimiter = ',', num_args = 1.., value_name = "ADDR,SIZE")]
dump: Vec<i32>,
dump: Vec<usize>,
}
fn main() -> Result<()> {
let args = Args::parse();
let config_json = retrive_config(&args)?;
let core_jsons = retrive_cores(&args)?;
let mut core_inputs = retrive_cores(&args)?;
let memory = retrive_memory(&args)?;
let global_crossbars = get_crossbars(&config_json, &args).unwrap();
let crossbars = map_crossbars_to_cores(&config_json, &args, &global_crossbars);
let mut executor =
json_to_executor::json_to_executor(config_json, core_jsons.iter(), crossbars);
let mut executor = match &mut core_inputs {
CoreInputs::Json(core_jsons) => {
json_to_executor::json_to_executor(config_json, core_jsons, crossbars)
}
CoreInputs::Binary(core_bins) => {
binary_to_executor(config_json, core_bins.iter(), crossbars)?
}
};
set_memory(&mut executor, memory);
TRACER
.lock()
.unwrap()
.init(executor.cpu().num_core(), args.output.clone());
executor.execute();
executor.execute()?;
dump_memory(executor, &args)?;
Ok(())
}
@@ -65,7 +77,7 @@ fn map_crossbars_to_cores<'c>(
args: &Args,
global_crossbars: &'c HashMap<String, Crossbar>,
) -> Vec<Vec<&'c Crossbar>> {
let mut res = Vec::new();
let mut res = vec![Vec::new()];
let num_cores = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
if let Some(folder) = args.folder.as_ref() {
@@ -140,8 +152,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
}
let bytes = std::fs::read(weight_file.path()).expect("Failed to read binary file");
let mut crossbar =
Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
let mut crossbar = Crossbar::new(column_corssbar * 4, rows_crossbar, CoreMemory::new());
crossbar.execute_store(&bytes).unwrap();
res.insert(
weight_file
@@ -157,7 +168,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
}
fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
let dumps: Vec<(i32, i32)> = args
let dumps: Vec<(usize, usize)> = args
.dump
.chunks_exact(2)
.map(|chunk| (chunk[0], chunk[1]))
@@ -214,45 +225,82 @@ fn retrive_memory(args: &Args) -> Result<Vec<u8>> {
Ok(memory_vector)
}
fn retrive_cores(args: &Args) -> Result<Vec<Value>, anyhow::Error> {
let mut core_jsons: Vec<Value> = Vec::new();
if let Some(cores_override) = &args.cores {
for core in cores_override {
let content = fs::read_to_string(core)
.with_context(|| format!("Failed to read core file: {:?}", cores_override))?;
let json: Value =
serde_json::from_str(&content).context("Failed to parse core json override")?;
core_jsons.push(json);
}
} else if let Some(folder) = args.folder.as_ref() {
let pattern = folder.join("core*.json");
let pattern_str = pattern.to_str().context("Invalid path encoding")?;
let mut paths: Vec<_> = glob(pattern_str)?.map(|x| x.unwrap()).collect();
paths.sort_by_cached_key(|x| {
let mut x = x
.file_stem()
.expect("Extracting the stem")
.to_str()
.expect("File not utf-8");
x = &x[5..];
x.parse::<i32>().unwrap()
});
enum CoreInputs {
Json(Vec<BufReader<File>>),
Binary(Vec<Vec<u8>>),
}
if paths.is_empty() {
bail!("No core*.json files found in {:?}", folder);
fn retrive_cores(args: &Args) -> Result<CoreInputs, anyhow::Error> {
if let Some(cores_override) = &args.cores {
let first_extension = cores_override
.first()
.and_then(|path| path.extension())
.and_then(|ext| ext.to_str())
.unwrap_or_default();
if first_extension == "pim" {
let mut core_bins = Vec::with_capacity(cores_override.len());
for core in cores_override {
core_bins.push(
fs::read(core)
.with_context(|| format!("Failed to read binary core file: {:?}", core))?,
);
}
return Ok(CoreInputs::Binary(core_bins));
}
for entry in paths {
let path = entry;
let content = fs::read_to_string(&path)
.with_context(|| format!("Failed to read core file: {:?}", path))?;
let json: Value = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse JSON in {:?}", path))?;
core_jsons.push(json);
let mut core_jsons_reader: Vec<BufReader<File>> = Vec::with_capacity(cores_override.len());
for core in cores_override {
let file = File::open(core)?;
let reader = BufReader::new(file);
core_jsons_reader.push(reader);
}
} else {
bail!("Either --core or --folder must be provided to find core definitions.");
return Ok(CoreInputs::Json(core_jsons_reader));
}
Ok(core_jsons)
if let Some(folder) = args.folder.as_ref() {
let binary_pattern = folder.join("core*.pim");
let binary_pattern_str = binary_pattern.to_str().context("Invalid path encoding")?;
let mut binary_paths: Vec<_> = glob(binary_pattern_str)?.map(|x| x.unwrap()).collect();
binary_paths.sort_by_cached_key(core_sort_key);
if !binary_paths.is_empty() {
let mut core_bins = Vec::with_capacity(binary_paths.len());
for path in binary_paths {
core_bins.push(
fs::read(&path)
.with_context(|| format!("Failed to read core file: {:?}", path))?,
);
}
return Ok(CoreInputs::Binary(core_bins));
}
let json_pattern = folder.join("core*.json");
let json_pattern_str = json_pattern.to_str().context("Invalid path encoding")?;
let mut json_paths: Vec<_> = glob(json_pattern_str)?.map(|x| x.unwrap()).collect();
json_paths.sort_by_cached_key(core_sort_key);
if json_paths.is_empty() {
bail!("No core*.pim or core*.json files found in {:?}", folder);
}
let mut core_json_reader: Vec<BufReader<File>> = Vec::with_capacity(json_paths.len());
for path in json_paths {
let file = File::open(path)?;
let reader = BufReader::new(file);
core_json_reader.push(reader);
}
return Ok(CoreInputs::Json(core_json_reader));
}
bail!("Either --core or --folder must be provided to find core definitions.");
}
fn core_sort_key(path: &PathBuf) -> i32 {
let mut stem = path
.file_stem()
.expect("Extracting the stem")
.to_str()
.expect("File not utf-8");
stem = &stem[5..];
stem.parse::<i32>().unwrap()
}
fn retrive_config(args: &Args) -> Result<Value, anyhow::Error> {
@@ -0,0 +1,497 @@
use crate::{
CoreInstructionsBuilder, Executable,
cpu::{CPU, crossbar::Crossbar},
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
};
use anyhow::{Context, Result, bail, ensure};
use serde_json::Value;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::sync::LazyLock;
const MAGIC: &[u8; 4] = b"PIMB";
const VERSION: u32 = 1;
const HEADER_SIZE: usize = 12;
const RECORD_SIZE: usize = 20;
macro_rules! add_name {
($storage:ident, $opcode:literal, $name:literal) => {
$storage.insert($opcode, $name);
};
}
static INSTRUCTIONS: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
let mut hash = HashMap::new();
add_name!(hash, 0, "nop");
add_name!(hash, 1, "sldi");
add_name!(hash, 2, "sld");
add_name!(hash, 3, "sadd");
add_name!(hash, 4, "ssub");
add_name!(hash, 5, "smul");
add_name!(hash, 6, "saddi");
add_name!(hash, 7, "smuli");
add_name!(hash, 8, "setbw");
add_name!(hash, 9, "mvmul");
add_name!(hash, 10, "vvadd");
add_name!(hash, 11, "vvsub");
add_name!(hash, 12, "vvmul");
add_name!(hash, 13, "vvdmul");
add_name!(hash, 14, "vvmax");
add_name!(hash, 15, "vvsll");
add_name!(hash, 16, "vvsra");
add_name!(hash, 17, "vavg");
add_name!(hash, 18, "vrelu");
add_name!(hash, 19, "vtanh");
add_name!(hash, 20, "vsigm");
add_name!(hash, 21, "vsoftmax");
add_name!(hash, 22, "vmv");
add_name!(hash, 23, "vrsu");
add_name!(hash, 24, "vrsl");
add_name!(hash, 25, "ld");
add_name!(hash, 26, "st");
add_name!(hash, 27, "lldi");
add_name!(hash, 28, "lmv");
add_name!(hash, 29, "send");
add_name!(hash, 30, "recv");
add_name!(hash, 31, "wait");
add_name!(hash, 32, "sync");
hash
});
#[derive(Clone, Copy, Debug, Default)]
struct InstructionRecord {
opcode: u8,
rd: u8,
r1: u8,
r2_or_imm: i32,
generic1: i32,
generic2: i32,
generic3: i32,
flags: u8,
}
fn read_u32_le(bytes: &[u8], offset: usize) -> u32 {
u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
}
fn read_i32_le(bytes: &[u8], offset: usize) -> i32 {
i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap())
}
fn parse_binary_records(bytes: &[u8]) -> Result<Vec<InstructionRecord>> {
ensure!(bytes.len() >= HEADER_SIZE, "binary core file too small");
ensure!(&bytes[0..4] == MAGIC, "invalid PIM binary magic");
let version = read_u32_le(bytes, 4);
ensure!(
version == VERSION,
"unsupported PIM binary version {version}"
);
let instruction_count = read_u32_le(bytes, 8) as usize;
let expected_len = HEADER_SIZE + instruction_count * RECORD_SIZE;
ensure!(
bytes.len() == expected_len,
"PIM binary size mismatch: expected {expected_len} bytes, got {}",
bytes.len()
);
let mut records = Vec::with_capacity(instruction_count);
for index in 0..instruction_count {
let base = HEADER_SIZE + index * RECORD_SIZE;
records.push(InstructionRecord {
opcode: bytes[base],
rd: bytes[base + 1],
r1: bytes[base + 2],
flags: bytes[base + 3],
r2_or_imm: read_i32_le(bytes, base + 4),
generic1: read_i32_le(bytes, base + 8),
generic2: read_i32_le(bytes, base + 12),
generic3: read_i32_le(bytes, base + 16),
});
}
Ok(records)
}
fn append_record(
inst_builder: &mut InstructionsBuilder,
inst_data_builder: &mut InstructionDataBuilder,
record: InstructionRecord,
) -> Result<()> {
let InstructionRecord {
opcode,
rd,
r1,
r2_or_imm,
generic1,
generic2,
generic3,
flags: _,
} = record;
match opcode {
0 => {}
1 => {
inst_data_builder.set_rd_u8(rd).set_imm(r2_or_imm);
inst_builder.make_inst(sldi, inst_data_builder.build());
}
2 => {
inst_data_builder
.set_rd_u8(rd)
.set_r1_u8(r1)
.set_offset_select(generic1)
.set_offset_value(generic2);
inst_builder.make_inst(sld, inst_data_builder.build());
}
3 => {
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(sadd, inst_data_builder.build());
}
4 => {
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(ssub, inst_data_builder.build());
}
5 => {
inst_data_builder.set_rdr1r2_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(smul, inst_data_builder.build());
}
6 => {
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(saddi, inst_data_builder.build());
}
7 => {
inst_data_builder.set_rdr1imm_u8(rd, r1, r2_or_imm);
inst_builder.make_inst(smuli, inst_data_builder.build());
}
8 => {
inst_data_builder.set_ibiw_obiw(generic1, generic2);
inst_builder.make_inst(setbw, inst_data_builder.build());
}
9 => {
inst_data_builder
.set_rd_u8(rd)
.set_r1_u8(r1)
.set_mbiw_immrelu_immgroup(r2_or_imm, generic1, generic2);
inst_builder.make_inst(mvmul, inst_data_builder.build());
}
10 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvadd, inst_data_builder.build());
}
11 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvsub, inst_data_builder.build());
}
12 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvmul, inst_data_builder.build());
}
13 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvdmul, inst_data_builder.build());
}
14 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvmax, inst_data_builder.build());
}
15 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvsll, inst_data_builder.build());
}
16 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vvsra, inst_data_builder.build());
}
17 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vavg, inst_data_builder.build());
}
18 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vrelu, inst_data_builder.build());
}
19 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vtanh, inst_data_builder.build());
}
20 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vsigm, inst_data_builder.build());
}
21 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vsoftmax, inst_data_builder.build());
}
22 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vmv, inst_data_builder.build());
}
23 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vrsu, inst_data_builder.build());
}
24 => {
inst_data_builder
.set_rdr1r2_u8(rd, r1, r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(vrsl, inst_data_builder.build());
}
25 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(ld, inst_data_builder.build());
}
26 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(st, inst_data_builder.build());
}
27 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm(r2_or_imm)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(lldi, inst_data_builder.build());
}
28 => {
inst_data_builder
.set_rdr1_u8(rd, r1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(lmv, inst_data_builder.build());
}
29 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm_core(r2_or_imm + 1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(send, inst_data_builder.build());
}
30 => {
inst_data_builder
.set_rd_u8(rd)
.set_imm_core(r2_or_imm + 1)
.set_imm_len(generic3)
.set_offset_select_value(generic1, generic2);
inst_builder.make_inst(recv, inst_data_builder.build());
}
31 => {
inst_builder.make_inst(wait, inst_data_builder.build());
}
32 => {
inst_builder.make_inst(sync, inst_data_builder.build());
}
_ => bail!("unsupported PIM binary opcode {opcode}"),
}
Ok(())
}
fn binary_to_instructions(
core_bytes: &[u8],
core_index: i32,
) -> Result<Vec<crate::instruction_set::Instruction>> {
let records = parse_binary_records(core_bytes)?;
let mut insts_builder = InstructionsBuilder::new();
let mut inst_data_builder = InstructionDataBuilder::new();
inst_data_builder
.set_core_indx_u16(u16::try_from(core_index).expect("core index does not fit in u16"))
.fix_core_indx();
for record in records {
let opcode = record.opcode;
let name = INSTRUCTIONS
.get(&(opcode as usize))
.copied()
.unwrap_or("<unknown>");
append_record(&mut insts_builder, &mut inst_data_builder, record).with_context(|| {
format!(
"while decoding binary instruction for core {core_index}: opcode {opcode} ({name})"
)
})?;
}
Ok(insts_builder.build())
}
pub fn binary_to_executor<'a, 'b>(
config: Value,
cores: impl Iterator<Item = &'b Vec<u8>>,
crossbars: Vec<Vec<&'a Crossbar>>,
) -> Result<Executable<'a>> {
let core_cnt = config
.get("core_cnt")
.context("missing core_cnt in config")?
.as_i64()
.context("core_cnt is not an integer")? as i32;
let cpu = CPU::new(core_cnt, crossbars);
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
for (external_core_indx, core_bytes) in cores.enumerate() {
let core_indx = external_core_indx as i32 + 1;
let instructions = binary_to_instructions(core_bytes, core_indx)?;
core_insts_builder.set_core(core_indx, instructions);
}
Ok(Executable::new(cpu, core_insts_builder.build()))
}
#[cfg(test)]
mod tests {
use super::{
HEADER_SIZE, InstructionRecord, MAGIC, RECORD_SIZE, VERSION, binary_to_instructions,
};
use crate::{
functor_to_name,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
json_to_instruction::json_isa::json_to_instruction,
};
fn encode_record(record: InstructionRecord, dst: &mut Vec<u8>) {
dst.push(record.opcode);
dst.push(record.rd);
dst.push(record.r1);
dst.push(record.flags);
dst.extend_from_slice(&record.r2_or_imm.to_le_bytes());
dst.extend_from_slice(&record.generic1.to_le_bytes());
dst.extend_from_slice(&record.generic2.to_le_bytes());
dst.extend_from_slice(&record.generic3.to_le_bytes());
}
fn binary_blob(records: &[InstructionRecord]) -> Vec<u8> {
let mut blob = Vec::with_capacity(HEADER_SIZE + records.len() * RECORD_SIZE);
blob.extend_from_slice(MAGIC);
blob.extend_from_slice(&VERSION.to_le_bytes());
blob.extend_from_slice(&(records.len() as u32).to_le_bytes());
for &record in records {
encode_record(record, &mut blob);
}
blob
}
#[test]
fn json_and_binary_decoders_match_for_representative_ops() {
let json_program = [
r#"{"imm":64,"op":"sldi","rd":0}"#,
r#"{"imm":128,"op":"sldi","rd":1}"#,
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"lmv","rd":0,"rs1":1}"#,
r#"{"group":3,"mbiw":8,"op":"mvmul","rd":0,"relu":0,"rs1":1}"#,
r#"{"len":16,"offset":{"offset_select":0,"offset_value":0},"op":"vvadd","rd":0,"rs1":1,"rs2":2}"#,
r#"{"core":2,"offset":{"offset_select":0,"offset_value":0},"op":"send","rd":0,"size":16}"#,
];
let binary_program = binary_blob(&[
InstructionRecord {
opcode: 1,
rd: 0,
r2_or_imm: 64,
..Default::default()
},
InstructionRecord {
opcode: 1,
rd: 1,
r2_or_imm: 128,
..Default::default()
},
InstructionRecord {
opcode: 28,
rd: 0,
r1: 1,
generic3: 16,
..Default::default()
},
InstructionRecord {
opcode: 9,
rd: 0,
r1: 1,
r2_or_imm: 8,
generic2: 3,
..Default::default()
},
InstructionRecord {
opcode: 10,
rd: 0,
r1: 1,
r2_or_imm: 2,
generic3: 16,
..Default::default()
},
InstructionRecord {
opcode: 29,
rd: 0,
r2_or_imm: 2,
generic3: 16,
..Default::default()
},
]);
let mut json_builder = InstructionsBuilder::new();
let mut json_data_builder = InstructionDataBuilder::new();
json_data_builder.set_core_indx(1).fix_core_indx();
for inst in json_program {
let value = serde_json::from_str(inst).unwrap();
json_to_instruction(&mut json_builder, &mut json_data_builder, &value);
}
let json_instructions = json_builder.build();
let binary_instructions = binary_to_instructions(&binary_program, 1).unwrap();
assert_eq!(json_instructions.len(), binary_instructions.len());
for (json_inst, binary_inst) in json_instructions.iter().zip(binary_instructions.iter()) {
assert_eq!(
functor_to_name(json_inst.functor as usize),
functor_to_name(binary_inst.functor as usize)
);
assert_eq!(json_inst.data, binary_inst.data);
}
}
}
@@ -1,3 +1,4 @@
use crate::utility::AddressArg;
use std::{collections::HashMap, fmt::Debug};
use anyhow::{Context, Result, ensure};
@@ -9,6 +10,7 @@ use crate::{
pub mod crossbar;
#[derive(Debug, Clone)]
pub struct CPU<'a> {
cores: Box<[Core<'a>]>,
@@ -91,30 +93,26 @@ impl<'a> Core<'a> {
self.memory.execute_load()
}
pub fn execute_store<T>(&mut self, address: impl TryToUsize, element: &[T]) -> Result<()>
pub fn execute_store<T>(&mut self, address: impl AddressArg, element: &[T]) -> Result<()>
where
T: MemoryStorable,
{
let address = address.try_into().context("address can not be negative")?;
let address = address.to_address_usize()?;
self.memory.execute_store(address, element)
}
pub fn reserve_load(
&mut self,
address: impl TryToUsize,
address: impl AddressArg,
size: impl TryToUsize,
) -> Result<&mut CoreMemory> {
let address = address.try_into().context("address can not be negative")?;
let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?;
self.memory.reserve_load(address, size)
}
pub fn set_register(&mut self, index: impl TryToUsize, value: i32) {
let index = index.try_into().expect("index can not be negative");
assert!(
value >= 0,
"Register cannot be negative if happens remove this and go check where it's used as usize"
);
self.registers[index] = value;
}
@@ -123,11 +121,11 @@ impl<'a> Core<'a> {
self.registers[index]
}
pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
pub fn load<T>(&mut self, address: impl AddressArg, size: impl TryToUsize) -> Result<Vec<&[T]>>
where
T: MemoryStorable,
{
let address = address.try_into().context("address can not be negative")?;
let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?;
self.memory.load(address, size)
}
@@ -141,8 +139,8 @@ impl<'a> Core<'a> {
(memory, crossbars)
}
pub fn memset(&mut self, address: impl TryToUsize, size: impl TryToUsize, val: u8) -> Result<()> {
let address = address.try_into().context("address can not be negative")?;
pub fn memset(&mut self, address: impl AddressArg, size: impl TryToUsize, val: u8) -> Result<()> {
let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?;
self.memory.memset(address, size, val)
}
@@ -1,10 +1,11 @@
use paste::paste;
use std::convert::TryFrom;
#[derive(Clone, Copy, Debug, Default)]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct InstructionData {
core_indx: i32,
rd: i32,
r1: i32,
core_indx: u16,
rd: u8,
r1: u8,
//r2 imm mbiw imm_core
r2_or_imm: i32,
//offset_select imm_relu ibiw
@@ -16,18 +17,30 @@ pub struct InstructionData {
}
impl InstructionData {
pub fn core_indx(&self) -> i32 {
pub fn core_indx_u16(&self) -> u16 {
self.core_indx
}
pub fn rd(&self) -> i32 {
pub fn core_indx(&self) -> i32 {
i32::from(self.core_indx)
}
pub fn rd_u8(&self) -> u8 {
self.rd
}
pub fn r1(&self) -> i32 {
pub fn rd(&self) -> i32 {
i32::from(self.rd)
}
pub fn r1_u8(&self) -> u8 {
self.r1
}
pub fn r1(&self) -> i32 {
i32::from(self.r1)
}
pub fn r2(&self) -> i32 {
self.r2_or_imm
}
@@ -49,26 +62,26 @@ impl InstructionData {
}
pub fn get_core_rd_r1(&self) -> (i32, i32, i32) {
(self.core_indx, self.rd, self.r1)
(self.core_indx(), self.rd(), self.r1())
}
pub fn get_core_rd_r1_r2(&self) -> (i32, i32, i32, i32) {
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
}
pub fn get_core_rd_imm(&self) -> (i32, i32, i32) {
(self.core_indx, self.rd, self.r2_or_imm)
(self.core_indx(), self.rd(), self.r2_or_imm)
}
pub fn get_core_rd_r1_imm(&self) -> (i32, i32, i32, i32) {
(self.core_indx, self.rd, self.r1, self.r2_or_imm)
(self.core_indx(), self.rd(), self.r1(), self.r2_or_imm)
}
pub fn get_core_rd_r1_r2_immlen_offset(&self) -> (i32, i32, i32, i32, i32, i32, i32) {
(
self.core_indx,
self.rd,
self.r1,
self.core_indx(),
self.rd(),
self.r1(),
self.r2_or_imm,
self.generic3,
self.generic1,
@@ -78,9 +91,9 @@ impl InstructionData {
pub fn get_core_rd_r1_mbiw_immrelu_immgroup(&self) -> (i32, i32, i32, i32, i32, i32) {
(
self.core_indx,
self.rd,
self.r1,
self.core_indx(),
self.rd(),
self.r1(),
self.r2_or_imm,
self.generic1,
self.generic2,
@@ -100,7 +113,7 @@ impl InstructionData {
}
pub(crate) fn get_core_immcore(&self) -> (i32, i32) {
(self.core_indx, self.r2_or_imm)
(self.core_indx(), self.r2_or_imm)
}
}
@@ -216,6 +229,18 @@ impl InstructionDataBuilder {
common_getter_setter![imm_group];
common_getter_setter![imm_core];
pub fn set_core_indx_u16(&mut self, val: u16) -> &mut Self {
self.set_core_indx(i32::from(val))
}
pub fn set_rd_u8(&mut self, val: u8) -> &mut Self {
self.set_rd(i32::from(val))
}
pub fn set_r1_u8(&mut self, val: u8) -> &mut Self {
self.set_r1(i32::from(val))
}
pub fn new() -> Self {
Self {
core_indx: Fixer::Edit(0),
@@ -254,20 +279,16 @@ impl InstructionDataBuilder {
fn check_sanity(&self) {
assert!(!(self.get_r2() != 0 && self.get_imm() != 0 && self.get_mbiw() != 0 && self.get_imm_core() != 0));
assert!(
!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0)
);
assert!(
!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0)
);
assert!(!(self.get_ibiw() != 0 && self.get_offset_select() != 0 && self.get_imm_relu() != 0));
assert!(!(self.get_obiw() != 0 && self.get_offset_value() != 0 && self.get_imm_group() != 0));
}
pub fn build(&mut self) -> InstructionData {
self.check_sanity();
let inst_data = InstructionData {
core_indx: self.get_core_indx(),
rd: self.get_rd(),
r1: self.get_r1(),
core_indx: u16::try_from(self.get_core_indx()).expect("core index does not fit in u16"),
rd: u8::try_from(self.get_rd()).expect("rd does not fit in u8"),
r1: u8::try_from(self.get_r1()).expect("r1 does not fit in u8"),
r2_or_imm: self.get_r2() + self.get_imm() + self.get_mbiw() + self.get_imm_core(),
generic1: self.get_offset_select() + self.get_ibiw() + self.get_imm_relu(),
generic2: self.get_offset_value() + self.get_obiw() + self.get_imm_group(),
@@ -281,6 +302,10 @@ impl InstructionDataBuilder {
self.set_rd(rd).set_r1(r1).set_r2(r2)
}
pub fn set_rdr1r2_u8(&mut self, rd: u8, r1: u8, r2: i32) -> &mut Self {
self.set_rd_u8(rd).set_r1_u8(r1).set_r2(r2)
}
pub fn set_offset_select_value(&mut self, offset_select: i32, offset_value: i32) -> &mut Self {
self.set_offset_select(offset_select)
.set_offset_value(offset_value)
@@ -290,14 +315,26 @@ impl InstructionDataBuilder {
self.set_rd(rd).set_r1(r1).set_imm(imm)
}
pub fn set_rdr1imm_u8(&mut self, rd: u8, r1: u8, imm: i32) -> &mut Self {
self.set_rd_u8(rd).set_r1_u8(r1).set_imm(imm)
}
pub fn set_rdr1(&mut self, rd: i32, r1: i32) -> &mut Self {
self.set_rd(rd).set_r1(r1)
}
pub fn set_rdr1_u8(&mut self, rd: u8, r1: u8) -> &mut Self {
self.set_rd_u8(rd).set_r1_u8(r1)
}
pub fn set_rdimm(&mut self, rd: i32, imm: i32) -> &mut Self {
self.set_rd(rd).set_imm(imm)
}
pub fn set_rdimm_u8(&mut self, rd: u8, imm: i32) -> &mut Self {
self.set_rd_u8(rd).set_imm(imm)
}
pub fn set_ibiw_obiw(&mut self, ibiw: i32, obiw: i32) -> &mut Self {
self.set_ibiw(ibiw).set_obiw(obiw)
}
@@ -1,17 +1,22 @@
use crate::{
cpu::{CPU, crossbar}, instruction_set::{
cpu::{CPU, crossbar},
instruction_set::{
Instruction, InstructionData, InstructionStatus, InstructionType, VectorBitWith,
helper::add_all,
}, memory_manager::{
},
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
}, tracing::TRACER, utility::{add_offset_r1, add_offset_r2, add_offset_rd}
},
tracing::TRACER,
utility::{add_offset_r1, add_offset_r2, add_offset_rd},
};
use aligned_vec::{AVec, ConstAlign};
use anyhow::{Context, Result, ensure};
use rayon::prelude::*;
use paste::paste;
use std::{borrow::Cow, cell::OnceCell, collections::HashMap};
use std::{borrow::Cow, cell::OnceCell, collections::HashMap };
use std::{collections::HashSet, sync::LazyLock};
macro_rules! add_name {
@@ -30,7 +35,7 @@ macro_rules! add_name_simd {
};
}
static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
pub static NAMES: LazyLock<HashMap<usize, &'static str>> = LazyLock::new(|| {
let mut hash = HashMap::new();
add_name!(hash, sldi);
add_name!(hash, sld);
@@ -76,8 +81,8 @@ pub fn functor_to_name(functor: usize) -> &'static str {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
{
#[inline(never)]
pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sldi(cores, data);
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core = cores.core(core_indx);
@@ -86,6 +91,7 @@ pub fn sldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sld(cores, data);
let (core_indx, rd, r1) = data.get_core_rd_r1();
@@ -100,6 +106,7 @@ pub fn sld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_sadd(cores, data);
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
@@ -110,6 +117,7 @@ pub fn sadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_ssub(cores, data);
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
@@ -120,6 +128,7 @@ pub fn ssub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_smul(cores, data);
let (core_indx, rd, r1, r2) = data.get_core_rd_r1_r2();
@@ -130,6 +139,7 @@ pub fn smul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_saddi(cores, data);
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
@@ -139,6 +149,7 @@ pub fn saddi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn smuli(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_smuli(cores, data);
let (core_indx, rd, r1, imm) = data.get_core_rd_r1_imm();
@@ -213,14 +224,17 @@ pub fn is_setbw(functor: InstructionType) -> bool {
functor as usize == setbw as *const () as usize
}
#[inline(never)]
pub fn setbw(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, this instruction is resolved in the construction phase");
}
#[inline(never)]
pub fn mvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn mvm_impl_internal<F, M, T>(
cores: &mut CPU,
data: InstructionData,
@@ -229,25 +243,30 @@ where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
// Add faer::ComplexField HERE, directly bounding M for this function only
M: UpcastDestTraits<M> + MemoryStorable + FromFloat + faer_traits::ComplexField,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().pre_mvm::<F, M, T>(cores, data);
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let group: usize = group.try_into().context("group can not be negative")?;
let core = cores.core(core_indx);
let r1_val = core.register(r1);
let rd_val = core.register(rd);
let (memory, crossbars) = core.get_memory_crossbar();
let crossbar = crossbars.get_mut(group).unwrap();
let crossbar_stored_bytes = crossbar.stored_bytes();
let crossbar_byte_width = crossbar.width();
//Fix this
let crossbar_elem_width = crossbar_byte_width / size_of::<M>();
ensure!(
crossbar_byte_width & size_of::<M>() == 0,
crossbar_byte_width % size_of::<M>() == 0,
"M not divisor of the crosbbar size"
);
let crossbar_height = crossbar.height();
let crossbar_byte_size = crossbar_byte_width * crossbar_height;
@@ -257,19 +276,29 @@ where
let load = loads[0];
let vec: Cow<[M]> = load.up();
let matrix = crossbar.load::<M>(crossbar_byte_size)?[0];
let mut res = Vec::with_capacity(crossbar_elem_width);
let mut partial :AVec<M, _> = AVec::<M, ConstAlign<64>>::with_capacity(64, vec.len());
partial.resize(vec.len(), M::from_f32(0.0));
for x in 0..crossbar_elem_width {
partial[0] = vec[0] * matrix[x];
for y in 1..crossbar_height {
partial[y] = vec[y] * matrix[y * crossbar_elem_width + x];
}
// --- FAER IMPLEMENTATION ---
// 1. Explicitly create a Matrix Reference (MatRef)
let matrix_view = faer::mat::MatRef::from_row_major_slice(
matrix.as_ref(),
crossbar_height,
crossbar_elem_width,
);
// 2. Explicitly create a Column Vector Reference (ColRef)
// Using `ColRef` here guarantees we don't accidentally get a RowRef (Fixes E0277)
let vec_view = faer::col::ColRef::from_slice(vec.as_ref());
let res_col: faer::col::Col<M> = matrix_view.transpose() * vec_view;
// 4. Convert back to standard Rust Vec
// try_as_slice() returns an Option<&[M]>.
// We can safely unwrap() because a freshly allocated, owned Col is ALWAYS contiguous!
let mut res: Vec<M> = (0..crossbar_elem_width).map(|i| res_col[i]).collect();
// --- END FAER ---
let mut acc = add_all(partial.as_slice());
res.push(acc);
}
if relu != 0 {
res.iter_mut().for_each(|x| {
if *x < M::from_f32(0.0) {
@@ -277,16 +306,20 @@ where
}
});
}
ensure!(
res.len() == crossbar_elem_width,
"mvm generate a vector bigger thant it's requested elements"
"mvm generate a vector bigger thant it's requested elements"
);
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_mvm::<F,M,T>(cores, data);
TRACER.lock().unwrap().post_mvm::<F, M, T>(cores, data);
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub(super) fn mvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T> + UpcastSlice<f32> + UpcastSlice<f64>,
@@ -307,17 +340,19 @@ where
}
}
#[inline(never)]
pub fn vvadd(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvadd_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvadd::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvadd::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -345,21 +380,23 @@ where
);
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vvadd::<F,T>(cores, data);
TRACER.lock().unwrap().post_vvadd::<F, T>(cores, data);
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvsub(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvsub_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvsub::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvsub::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -394,13 +431,14 @@ pub fn vvmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvmul::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvmul::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -430,17 +468,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvdmul(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvdmul_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvdmul::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvdmul::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -466,17 +506,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vvmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vvmax::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vvmax::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -503,29 +545,33 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vvsll(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!(
"Shift left on floating point what does it means? who has generated this instruction???"
);
}
#[inline(never)]
pub fn vvsra(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!(
"Shift right on floating point what does it means? who has generated this instruction???"
);
}
#[inline(never)]
pub fn vavg(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vavg_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
TRACER.lock().unwrap().pre_vavg::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vavg::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -533,7 +579,10 @@ where
let r2_val = r2;
ensure!(r2_val == 1, "Stride different than 1 not supported");
let rd_val = core.register(rd);
ensure!(offset_select == 1, "Offset select cannot be different from 1");
ensure!(
offset_select == 1,
"Offset select cannot be different from 1"
);
let r1_val = add_offset_r1(r1_val, offset_select, offset_value);
let loads = core.reserve_load(r1_val, imm_len)?.execute_load::<F>()?;
let load1 = loads[0];
@@ -545,17 +594,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vrelu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vrelu_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vrelu::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vrelu::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -575,17 +626,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vtanh(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vtanh_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vtanh::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vtanh::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -603,17 +656,19 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vsigm(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
#[inline(never)]
pub(super) fn vsigm_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vsigm::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vsigm::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -629,17 +684,22 @@ where
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vsoftmax(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
panic!("You are calling a placeholder, the real call is the generic version");
}
pub(super) fn vsoftmax_impl<F, T>(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
#[inline(never)]
pub(super) fn vsoftmax_impl<F, T>(
cores: &mut CPU,
data: InstructionData,
) -> Result<InstructionStatus>
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
TRACER.lock().unwrap().pre_vsoftmax::<F,T>(cores, data);
TRACER.lock().unwrap().pre_vsoftmax::<F, T>(cores, data);
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let core = cores.core(core_indx);
@@ -656,27 +716,29 @@ where
.reduce(|a, b| if a > b { a } else { b })
.unwrap();
let exp_values: Vec<F> = load1.iter().map(|&a| (a - max_val).exp()).collect();
let sum = exp_values
.iter()
.copied()
.reduce(|a, b| a + b)
.unwrap();
ensure!(sum > 0.0.into(), "vsoftmax normalization sum must be positive");
let sum = exp_values.iter().copied().reduce(|a, b| a + b).unwrap();
ensure!(
sum > 0.0.into(),
"vsoftmax normalization sum must be positive"
);
let res: Vec<F> = exp_values.iter().map(|&a| a / sum).collect();
let res_up: Cow<[T]> = res.as_slice().up();
core.execute_store(rd_val, res_up.as_ref());
TRACER.lock().unwrap().post_vsoftmax::<F,T>(cores, data);
TRACER.lock().unwrap().post_vsoftmax::<F, T>(cores, data);
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn vmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
todo!()
}
#[inline(never)]
pub fn vrsu(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
todo!()
}
#[inline(never)]
pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
todo!()
}
@@ -684,6 +746,7 @@ pub fn vrsl(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
///////////////////////////////////////////////////////////////
///Communication/synchronization Instructions/////////////////
///////////////////////////////////////////////////////////////
#[inline(never)]
pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_ld(cores, data);
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
@@ -700,6 +763,7 @@ pub fn ld(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_st(cores, data);
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
@@ -716,6 +780,7 @@ pub fn st(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_lldi(cores, data);
let (core, rd, imm) = data.get_core_rd_imm();
@@ -732,6 +797,7 @@ pub fn lldi(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_lmv(cores, data);
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
@@ -748,20 +814,32 @@ pub fn lmv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus>
Ok(InstructionStatus::Completed)
}
#[inline(never)]
pub fn isa_send(functor : usize) -> bool{
(send as *const () as usize) == functor
}
#[inline(never)]
pub fn send(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_send(cores, data);
Ok(InstructionStatus::Sending(data))
}
#[inline(never)]
pub fn isa_recv(functor : usize) -> bool{
(recv as *const () as usize) == functor
}
#[inline(never)]
pub fn recv(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
TRACER.lock().unwrap().pre_recv(cores, data);
Ok(InstructionStatus::Reciving(data))
}
#[inline(never)]
pub fn wait(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Waiting(data))
}
#[inline(never)]
pub fn sync(cores: &mut CPU, data: InstructionData) -> Result<InstructionStatus> {
Ok(InstructionStatus::Sync(data))
}
@@ -14,7 +14,7 @@ pub mod helper;
#[derive(Clone, Copy, Debug)]
pub struct Instruction {
pub data: InstructionData,
functor: InstructionType,
pub functor: InstructionType,
}
#[derive(Debug, Clone, Copy, Default)]
@@ -567,7 +567,7 @@ fn json_to_send(
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_imm_core(core)
.set_imm_core(core + 1)
.set_imm_len(size)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
@@ -588,7 +588,7 @@ fn json_to_recv(
let (offset_select, offset_value) = json_to_offset(json.get("offset").unwrap());
inst_data_builder
.set_rd(rd)
.set_imm_core(core)
.set_imm_core(core + 1)
.set_imm_len(size)
.set_offset_select(offset_select)
.set_offset_value(offset_value);
@@ -1,45 +1,32 @@
use core::panic;
use std::collections::HashMap;
use serde_json::{Map, Value};
use serde_json::Value;
use std::{fs::File, io::BufReader};
use crate::{
CoreInstructionsBuilder, Executable,
cpu::{CPU, crossbar::{self, Crossbar}},
instruction_set::{
InstructionsBuilder,
instruction_data::{self, InstructionData, InstructionDataBuilder},
},
json_to_instruction::{self, json_isa},
memory_manager::type_traits::TryToUsize,
cpu::{CPU, crossbar::Crossbar},
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder},
json_to_instruction::json_isa,
};
pub fn json_to_executor<'a>(
pub fn json_to_executor<'a, 'b>(
config: Value,
mut cores: impl Iterator<Item = &'a Value>,
crossbars : Vec<Vec<&'a Crossbar>>
cores: &'b mut Vec<BufReader<File>>,
crossbars: Vec<Vec<&'a Crossbar>>,
) -> Executable<'a> {
let cell_precision = config.get("cell_precision").unwrap().as_i64().unwrap() as i32;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32 - 1;
let xbar_count = config.get("xbar_array_count").unwrap().as_i64().unwrap() as i32;
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
let rows_crossbar = xbar_size[0].as_i64().unwrap() as i32;
let column_corssbar = xbar_size[1].as_i64().unwrap() as i32;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as i32;
let mut cpu = CPU::new(core_cnt, crossbars);
let cpu = CPU::new(core_cnt, crossbars);
let mut core_insts_builder = CoreInstructionsBuilder::new(core_cnt as usize);
cores.next();
for core_indx in 1..=core_cnt {
for (external_core_indx, json_core_reader) in cores.iter_mut().enumerate() {
let core_indx = external_core_indx as i32 + 1;
let mut insts_builder = InstructionsBuilder::new();
let mut inst_data_builder = InstructionDataBuilder::new();
inst_data_builder.set_core_indx(core_indx).fix_core_indx();
let json_core = cores
.next()
.unwrap_or_else(|| panic!("cores files less than {}", core_indx ));
let json_core: Value = serde_json::from_reader(json_core_reader)
.unwrap_or_else(|err| panic!("failed to parse core{}: {}", external_core_indx, err));
let json_core_insts = json_core
.as_array()
.unwrap_or_else(|| panic!("core{} has not a list of instruction", core_indx));
.unwrap_or_else(|| panic!("core{} has not a list of instruction", external_core_indx));
for json_inst in json_core_insts {
json_isa::json_to_instruction(&mut insts_builder, &mut inst_data_builder, json_inst);
}
@@ -1,2 +1,2 @@
mod json_isa;
pub(crate) mod json_isa;
pub mod json_to_executor;
@@ -1,3 +1,4 @@
use std::cmp::min;
use std::fmt::Debug;
use anyhow::{Context, Result, bail, ensure};
@@ -86,7 +87,7 @@ where {
size,
};
if self.memory.len() < address + size {
self.memory.resize((address + size) * 2, 0);
self.memory.resize(min((address + size) * 2, u32::MAX as usize), 0);
}
self.load_requests.push(load_request);
Ok(self)
@@ -55,15 +55,23 @@ pub trait HasSigm {
impl HasSigm for f32 {
fn sigm(self) -> Self {
let ex = self.exp();
ex / (1.0 + ex)
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
impl HasSigm for f64 {
fn sigm(self) -> Self {
let ex = self.exp();
ex / (1.0 + ex)
if self >= 0.0 {
1.0 / (1.0 + (-self).exp())
} else {
let ex = self.exp();
ex / (1.0 + ex)
}
}
}
@@ -1,14 +1,22 @@
#![allow(unused)]
use std::time::{Duration, SystemTime};
use anyhow::{Result, bail};
use std::{
collections::{HashMap, HashSet},
time::{Duration, SystemTime},
};
use crate::{
cpu::CPU,
instruction_set::{Instruction, InstructionStatus, Instructions, isa::functor_to_name},
instruction_set::{
Instruction, InstructionStatus, Instructions,
isa::{NAMES, functor_to_name, isa_recv, isa_send},
},
memory_manager::type_traits::TryToUsize,
send_recv::{SendRecv, handle_send_recv},
tracing::TRACER,
};
pub mod binary_to_instruction;
pub mod cpu;
pub mod instruction_set;
pub mod json_to_instruction;
@@ -80,6 +88,11 @@ pub struct Executable<'a> {
send_recv: SendRecv,
}
struct DeadlockInfo {
cycle: String,
states: String,
}
fn print_status(core_instructions: &[CoreInstructions]) {
let mut tot_instructions = 0;
let mut progress = 0;
@@ -111,7 +124,7 @@ impl<'a> Executable<'a> {
}
}
pub fn execute<'b>(&'b mut self)
pub fn execute<'b>(&'b mut self) -> Result<()>
where
'a: 'b,
{
@@ -144,8 +157,15 @@ impl<'a> Executable<'a> {
cpu_progressed = 0;
*program_counter += 1;
}
if (now.elapsed().unwrap() > Duration::from_secs(1)) {
print_status(&cores_instructions);
if (now.elapsed().unwrap() > Duration::from_secs(5)) {
print_status(cores_instructions);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
now = SystemTime::now();
}
}
@@ -169,6 +189,24 @@ impl<'a> Executable<'a> {
}
}
print_status(cores_instructions);
if let Some(deadlock) = detect_deadlock(cores_instructions) {
bail!(
"Deadlock cycle detected: {} [{}]",
deadlock.cycle,
deadlock.states
);
}
if cores_instructions
.iter()
.any(|core_inst| core_inst.program_counter < core_inst.instructions.len())
{
bail!("Execution stalled with unfinished instructions");
}
#[cfg(feature = "profile_time")]
TRACER.lock().unwrap().report();
Ok(())
}
pub fn cpu(&self) -> &CPU<'a> {
@@ -190,6 +228,125 @@ impl<'a> Executable<'a> {
}
}
fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockInfo> {
#[derive(Debug, PartialEq, Eq)]
enum CoreState {
SendingTo(i32, i32),
ReceivingFrom(i32, i32),
Working,
Halted,
}
let mut states = HashMap::new();
for core_inst in cores_instructions.iter() {
if core_inst.program_counter >= core_inst.instructions.len() {
continue;
}
let Instruction { data, functor } = core_inst.instructions[core_inst.program_counter];
let functor_address = functor as usize;
let (this_core, target_core) = data.get_core_immcore();
if isa_recv(functor_address) {
states.insert(this_core, CoreState::ReceivingFrom(target_core, data.imm_len()));
} else if isa_send(functor_address) {
states.insert(this_core, CoreState::SendingTo(target_core, data.imm_len()));
} else {
states.insert(this_core, CoreState::Working);
}
}
let mut wait_for = HashMap::new();
for (&core_id, state) in states.iter() {
match state {
CoreState::SendingTo(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::ReceivingFrom(core_id, *size) {
wait_for.insert(core_id, *target_core);
}
}
CoreState::ReceivingFrom(target_core, size) => {
let target_state = states.get(target_core).unwrap_or(&CoreState::Halted);
if target_state != &CoreState::SendingTo(core_id, *size) {
wait_for.insert(core_id, *target_core);
}
}
CoreState::Working | CoreState::Halted => {
}
}
}
let mut visited = HashSet::new();
for &start_core in wait_for.keys() {
if visited.contains(&start_core) {
continue;
}
let mut path = Vec::new();
let mut current_core = start_core;
let mut in_path = HashSet::new();
while let Some(&waiting_for) = wait_for.get(&current_core) {
path.push(current_core);
in_path.insert(current_core);
visited.insert(current_core);
// Found a closed loop!
if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..];
let format_core = |core: &i32| (core - 1).to_string();
let cycle_str = cycle
.iter()
.map(format_core)
.collect::<Vec<_>>()
.join(" -> ");
let cycle = cycle
.iter()
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core - 1, size, target - 1)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
}
CoreState::Working => format!("core {} working", core - 1),
CoreState::Halted => format!("core {} halted", core - 1),
})
})
.collect::<Vec<_>>()
.join(", ");
return Some(DeadlockInfo {
cycle: cycle_msg,
states: states_msg,
});
}
// Hit a known branch that didn't result in a cycle
if visited.contains(&waiting_for) {
break;
}
current_core = waiting_for;
}
}
None
}
fn handle_wait_sync<'a, 'b, 'c>(
cpu: &'b mut CPU<'a>,
core_instructions: &'c mut [CoreInstructions],
@@ -58,6 +58,20 @@ where 'a : 'b
&& sender.internal_core == receiver.external_core
&& receiver.internal_core == sender.external_core
{
{
let sender = &mut core_instructions[sender.internal_core];
let pc = sender.program_counter;
let inst = sender.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_send(cpu, data);
}
{
let recv = &mut core_instructions[receiver.internal_core];
let pc = recv.program_counter;
let inst = recv.instructions.get(pc).unwrap();
let data = inst.data;
TRACER.lock().unwrap().pre_recv(cpu, data);
}
let [sender_core, reciver_core] =
cpu.get_multiple_cores([sender.internal_core, receiver.internal_core]);
let memory = sender_core
@@ -13,7 +13,7 @@ use crate::{
};
use std::io::Write;
#[cfg(not(feature = "tracing"))]
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
@@ -1,52 +1,32 @@
mod tracing_isa;
mod disable;
mod pretty_print;
use std::{fs::File, path::{ PathBuf}};
#[cfg(feature = "profile_time")]
mod profile;
#[cfg(feature = "profile_time")]
use profile::Trace;
#[cfg(feature = "tracing")]
mod trace;
#[cfg(feature = "tracing")]
use trace::Trace;
use crate::Executable;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
use std::path::PathBuf;
use std::sync::{LazyLock, Mutex};
use crate::Executable;
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
pub struct Trace {}
#[cfg(feature = "tracing")]
pub struct Trace {
out_files : Vec<File>
}
#[cfg(feature = "tracing")]
#[cfg(not(any(feature = "tracing", feature = "profile_time")))]
impl Trace {
fn new() -> Self {
Self { out_files : Vec::new()}
Self {}
}
pub fn init(&mut self, num_core : usize , mut path : PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
#[cfg(not(feature = "tracing"))]
pub struct Trace {
pub fn init(&mut self, num_core: usize, path: PathBuf) {}
}
#[cfg(not(feature = "tracing"))]
impl Trace {
fn new() -> Self {
Self { }
}
pub fn init(&mut self, num_core : usize, path : PathBuf ) {
}
}
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| { Trace::new().into()});
pub static TRACER: LazyLock<Mutex<Trace>> = LazyLock::new(|| Trace::new().into());
@@ -0,0 +1,73 @@
use std::{collections::HashMap, path::PathBuf, time::Instant};
use crate::tracing::profile::profile_analysis::{
analyze_timings, generate_interactive_report, print_textual_report,
};
pub mod profile_analysis;
pub mod profile_isa;
pub struct Trace {
instruction_times: HashMap<String, Vec<(u128,u128)>>,
core_start_time: HashMap<usize, Option<Instant>>,
start_time: Instant,
}
impl Trace {
pub fn new() -> Self {
let mut instruction_times = HashMap::new();
instruction_times.insert("sldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ssub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("saddi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("smuli".to_string(), Vec::with_capacity(20000));
instruction_times.insert("setbw".to_string(), Vec::with_capacity(20000));
instruction_times.insert("mvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvadd".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsub".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvdmul".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsll".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vvsra".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vavg".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrelu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vtanh".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsigm".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vsoftmax".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsu".to_string(), Vec::with_capacity(20000));
instruction_times.insert("vrsl".to_string(), Vec::with_capacity(20000));
instruction_times.insert("ld".to_string(), Vec::with_capacity(20000));
instruction_times.insert("st".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lldi".to_string(), Vec::with_capacity(20000));
instruction_times.insert("lmv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("send".to_string(), Vec::with_capacity(20000));
instruction_times.insert("recv".to_string(), Vec::with_capacity(20000));
instruction_times.insert("wait".to_string(), Vec::with_capacity(20000));
instruction_times.insert("sync".to_string(), Vec::with_capacity(20000));
Self {
instruction_times,
core_start_time: HashMap::new(),
start_time: Instant::now()
}
}
pub fn init(&mut self, num_core: usize, path: PathBuf) {
for i in 0..num_core {
self.core_start_time.insert(i, None);
}
}
pub fn report(&self) {
let res = analyze_timings(&self.instruction_times);
print_textual_report(&res);
generate_interactive_report(
&self.instruction_times,
&["mvmul", "recv"],
"/tmp/report.html",
);
}
}
@@ -0,0 +1,192 @@
use comfy_table::{Cell, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL};
use statrs::statistics::{Data, Distribution, Max, Min, OrderStatistics};
use std::collections::HashMap;
#[derive(Debug)]
pub struct InstructionStats {
pub name: String,
pub count: usize,
pub total_time: u128,
pub min: f64,
pub max: f64,
pub mean: f64,
pub median: f64,
pub std_dev: f64,
pub cv: f64,
pub p95: f64,
pub p99: f64,
pub skewness: f64,
pub kurtosis: f64,
}
fn format_time(ns: f64) -> String {
if ns.is_nan() {
return "NaN".to_string();
}
if ns >= 1_000_000_000.0 {
format!("{:.2} s", ns / 1_000_000_000.0)
} else if ns >= 1_000_000.0 {
format!("{:.2} ms", ns / 1_000_000.0)
} else if ns >= 1_000.0 {
format!("{:.2} µs", ns / 1_000.0)
} else {
format!("{:.2} ns", ns)
}
}
fn calculate_skewness_kurtosis(times: &[f64], mean: f64, std_dev: f64) -> (f64, f64) {
let n = times.len() as f64;
if n < 4.0 || std_dev == 0.0 {
return (f64::NAN, f64::NAN);
}
let mut sum_m3 = 0.0;
let mut sum_m4 = 0.0;
for &x in times {
let deviation = x - mean;
sum_m3 += deviation.powi(3);
sum_m4 += deviation.powi(4);
}
let m3 = sum_m3 / n;
let m4 = sum_m4 / n;
let skewness = m3 / std_dev.powi(3);
let kurtosis = (m4 / std_dev.powi(4)) - 3.0;
(skewness, kurtosis)
}
pub fn analyze_timings(timings: &HashMap<String, Vec<(u128, u128)>>) -> Vec<InstructionStats> {
let mut results = Vec::new();
for (instruction, times) in timings {
let count = times.len();
if count == 0 {
continue;
}
// Extract ONLY the duration (the second element of the tuple) for stats
let durations: Vec<u128> = times.iter().map(|&(_, duration)| duration).collect();
let total_time: u128 = durations.iter().sum();
let f64_times: Vec<f64> = durations.iter().map(|&t| t as f64).collect();
let mut data = Data::new(f64_times.clone());
let mean = data.mean().unwrap_or(0.0);
let std_dev = data.std_dev().unwrap_or(0.0);
let cv = if mean > 0.0 { std_dev / mean } else { 0.0 };
let (skewness, kurtosis) = calculate_skewness_kurtosis(&f64_times, mean, std_dev);
results.push(InstructionStats {
name: instruction.clone(),
count,
total_time,
min: data.min(),
max: data.max(),
mean,
median: data.median(),
std_dev,
cv,
p95: data.percentile(95),
p99: data.percentile(99),
skewness,
kurtosis,
});
}
results.sort_by(|a, b| b.mean.partial_cmp(&a.mean).unwrap());
results
}
pub fn print_textual_report(stats: &[InstructionStats]) {
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_header(vec![
"Instruction",
"Count",
"Total Time",
"Mean",
"Median",
"Min",
"Max",
"P95",
"P99",
"StdDev",
"CV",
"Skewness",
"Kurtosis",
]);
for stat in stats {
table.add_row(vec![
Cell::new(&stat.name),
Cell::new(stat.count.to_string()),
Cell::new(format_time(stat.total_time as f64)), // Cast u128 to f64 for formatting
Cell::new(format_time(stat.mean)),
Cell::new(format_time(stat.median)),
Cell::new(format_time(stat.min)),
Cell::new(format_time(stat.max)),
Cell::new(format_time(stat.p95)),
Cell::new(format_time(stat.p99)),
Cell::new(format_time(stat.std_dev)),
Cell::new(format!("{:.3}", stat.cv)),
Cell::new(format!("{:.2}", stat.skewness)),
Cell::new(format!("{:.2}", stat.kurtosis)),
]);
}
println!("{table}");
}
pub fn generate_interactive_report(
timings: &HashMap<String, Vec<(u128, u128)>>,
instructions_to_plot: &[&str], // <-- NEW: Only plot these
file_path: &str,
) {
use plotly::common::{Mode, Marker, Line};
use plotly::layout::{Axis, Layout};
use plotly::{Plot, Scatter};
use std::collections::HashMap;
let mut plot = Plot::new();
for &instruction_name in instructions_to_plot {
// Only proceed if the instruction exists in our timings map
if let Some(times) = timings.get(instruction_name) {
let x_axis: Vec<f64> = times.iter().map(|&(ts, _)| ts as f64).collect();
let y_axis: Vec<f64> = times.iter().map(|&(_, dur)| dur as f64).collect();
let text_array: Vec<String> = times.iter()
.map(|&(_, dur)| format_time(dur as f64))
.collect();
let trace = Scatter::new(x_axis, y_axis)
.name(instruction_name)
.mode(Mode::LinesMarkers)
.marker(Marker::new().size(4).opacity(0.6))
.line(Line::new().width(1.0))
.text_array(text_array)
.hover_info(plotly::common::HoverInfo::All);
plot.add_trace(trace);
}
}
let layout = Layout::new()
.title(plotly::common::Title::new("Simulator Timeline: Top Offenders"))
.x_axis(Axis::new().title(plotly::common::Title::new("Absolute Time (ns)")))
.y_axis(Axis::new().title(plotly::common::Title::new("Execution Duration")));
plot.set_layout(layout);
plot.write_html(file_path);
println!("🌐 Interactive timeline saved to {}", file_path);
}
@@ -0,0 +1,364 @@
use crate::{
cpu::CPU,
instruction_set::instruction_data::InstructionData,
memory_manager::{
MemoryStorable,
type_traits::{FromFloat, UpcastDestTraits, UpcastSlice},
},
tracing::Trace,
utility::{add_offset_r1, add_offset_rd},
};
use std::io::Write;
use std::time::Instant;
#[cfg(feature = "profile_time")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
///////////////////////////////////////////////////////////////
fn pre_impl(&mut self, cores: &mut CPU, data: InstructionData) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
if self.core_start_time.get(&core_indx).unwrap().is_none() {
self.core_start_time.insert(core_indx, Some(Instant::now()));
}
}
fn post_impl(&mut self, cores: &mut CPU, data: InstructionData, name: &'static str) {
let (core_indx, rd, imm) = data.get_core_rd_imm();
let core_indx = core_indx as usize;
let Self {
instruction_times,
core_start_time,
start_time,
} = self;
let now = Instant::now();
instruction_times
.get_mut(name)
.unwrap()
.push((now.duration_since(*start_time).as_nanos(), now.duration_since(core_start_time[&core_indx].unwrap()).as_nanos()));
self.core_start_time.insert(core_indx, None);
}
pub fn pre_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sldi");
}
pub fn pre_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sld");
}
pub fn pre_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_sadd(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "sadd");
}
pub fn pre_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ssub(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ssub");
}
pub fn pre_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smul(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smul");
}
pub fn pre_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_saddi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "saddi");
}
pub fn pre_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_smuli(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "smuli");
}
/////////////////////////////////////////////////////////////////
///////////////////Matrix/vector Instructions////////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_setbw(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "setbw");
}
pub fn pre_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_mvm<F, M, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T> + UpcastSlice<M>,
[M]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "mvmul");
}
pub fn pre_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvadd<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvadd");
}
pub fn pre_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvsub<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvsub");
}
pub fn pre_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmul");
}
pub fn pre_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvdmul<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvdmul");
}
pub fn pre_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vvmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vvmax");
}
pub fn pre_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.pre_impl(cores, data);
}
pub fn post_vavg<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
self.post_impl(cores, data, "vavg");
}
pub fn pre_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vrelu<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vrelu");
}
pub fn pre_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vtanh<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vtanh");
}
pub fn pre_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsigm<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsigm");
}
pub fn pre_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.pre_impl(cores, data);
}
pub fn post_vsoftmax<F, T>(&mut self, cores: &mut CPU, data: InstructionData)
where
[F]: UpcastSlice<T>,
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable + From<f32>,
{
self.post_impl(cores, data, "vsoftmax");
}
/////////////////////////////////////////////////////////////////
/////Communication/synchronization Instructions/////////////////
/////////////////////////////////////////////////////////////////
pub fn pre_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_ld(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "ld");
}
pub fn pre_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_st(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "st");
}
pub fn pre_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lldi(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lldi");
}
pub fn pre_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_lmv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "lmv");
}
pub fn pre_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_send(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "send");
}
pub fn pre_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.pre_impl(cores, data);
}
pub fn post_recv(&mut self, cores: &mut CPU, data: InstructionData) {
self.post_impl(cores, data, "recv");
}
}
@@ -0,0 +1,28 @@
use std::{fs::File, path::PathBuf};
pub mod pretty_print;
pub mod tracing_isa;
pub struct Trace {
out_files: Vec<File>,
}
impl Trace {
pub fn new() -> Self {
Self {
out_files: Vec::new(),
}
}
pub fn init(&mut self, num_core: usize, mut path: PathBuf) {
path.pop();
for i in 0..num_core {
path.push(format!("TraceCore{}", i));
let file = File::create(&path).expect("Can not create file");
self.out_files.push(file);
path.pop();
}
}
}
@@ -1,4 +1,4 @@
use crate::tracing::pretty_print;
use crate::{tracing::trace::pretty_print, utility::add_offset_r2};
use std::fs::File;
use crate::{
@@ -13,7 +13,6 @@ use crate::{
};
use std::io::Write;
#[cfg(feature = "tracing")]
impl Trace {
///////////////////////////////////////////////////////////////
/////////////////Scalar/register Instructions//////////////////
@@ -284,7 +283,6 @@ impl Trace {
M: UpcastDestTraits<M> + MemoryStorable + FromFloat,
F: UpcastDestTraits<F> + MemoryStorable,
{
use crate::tracing::pretty_print;
let (core_indx, rd, r1, mbiw, relu, group) = data.get_core_rd_r1_mbiw_immrelu_immgroup();
let file: &mut File = self
@@ -358,8 +356,6 @@ impl Trace {
T: UpcastDestTraits<T> + MemoryStorable,
F: UpcastDestTraits<F> + MemoryStorable,
{
use crate::{tracing::pretty_print, utility::add_offset_r2};
let (core_indx, rd, r1, r2, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -990,8 +986,6 @@ impl Trace {
/////////////////////////////////////////////////////////////////
pub fn ld_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -1044,8 +1038,6 @@ impl Trace {
}
pub fn st_impl(&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
let file: &mut File = self
@@ -1138,7 +1130,6 @@ impl Trace {
}
fn lmv_impl (&mut self, cores: &mut CPU, data: InstructionData, prefix: &'static str) {
use crate::tracing::pretty_print;
let (core, rd, r1, _, imm_len, offset_select, offset_value) =
data.get_core_rd_r1_r2_immlen_offset();
@@ -1,7 +1,45 @@
use anyhow::{Result,Context};
use std::{fmt::Debug, mem::transmute};
use crate::memory_manager::type_traits::TryToUsize;
pub trait AddressArg {
fn to_address_usize(self) -> Result<usize>;
}
impl AddressArg for usize {
fn to_address_usize(self) -> Result<usize> {
Ok(self)
}
}
impl AddressArg for u32 {
fn to_address_usize(self) -> Result<usize> {
Ok(self as usize)
}
}
impl AddressArg for u64 {
fn to_address_usize(self) -> Result<usize> {
usize::try_from(self).context("address does not fit in usize")
}
}
impl AddressArg for i32 {
fn to_address_usize(self) -> Result<usize> {
Ok(self as u32 as usize)
}
}
impl AddressArg for i64 {
fn to_address_usize(self) -> Result<usize> {
usize::try_from(self).context("address can not be negative")
}
}
fn address_to_usize(address: i32) -> usize {
address as u32 as usize
}
fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i32) -> usize{
assert!(offset_select == 1 || offset_select == 2 || offset_select == 4 || offset_value == 0, "offset_select not a bit field");
@@ -14,21 +52,21 @@ fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i
}
pub fn add_offset_rd(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
pub fn add_offset_rd(address: i32, offset_select : i32, offset_value : i32) -> usize
{
let address = address.try_into().expect("address can not be negative");
let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 4)
}
pub fn add_offset_r1(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
pub fn add_offset_r1(address: i32, offset_select : i32, offset_value : i32) -> usize
{
let address = address.try_into().expect("address can not be negative");
let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 1)
}
pub fn add_offset_r2(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
pub fn add_offset_r2(address: i32, offset_select : i32, offset_value : i32) -> usize
{
let address = address.try_into().expect("address can not be negative");
let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 2)
}
@@ -1,6 +1,11 @@
use std::path::Path;
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
use pimcore::{
Executable,
cpu::crossbar::Crossbar,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
memory_manager::CoreMemory,
};
fn simple_read(path: &Path) -> Vec<f32> {
if !path.exists() {
@@ -17,14 +22,12 @@ fn simple_read(path: &Path) -> Vec<f32> {
fn mvmul_f32(err: &str)
where
{
let mut cpu = CPU::new(0);
cpu.reserve_crossbar(1, 1024 * size_of::<f32>(), 1024);
let (memory, crossbars) = cpu.host().get_memory_crossbar();
let matrix = simple_read(Path::new("B.txt")) ;
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
let vector = simple_read(Path::new("A.txt"));
let matrix = simple_read(Path::new("tests/B.txt"));
let mut crossbar = Crossbar::new(1024 * size_of::<f32>(), 1024, CoreMemory::new());
crossbar.execute_store(&matrix).unwrap();
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
let (memory, _) = cpu.host().get_memory_crossbar();
let vector = simple_read(Path::new("tests/A.txt"));
memory.execute_store(0, &vector).unwrap();
let mut inst_builder = InstructionsBuilder::new();
@@ -57,7 +60,7 @@ where
.cpu_mut()
.host()
.load::<f32>(1024 * size_of::<f32>(), 1024*size_of::<f32>()).unwrap()[0].iter().zip(
simple_read(Path::new("X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
simple_read(Path::new("tests/X.txt")) ).all(|(&a,b) : (&f32, f32)| {a-b < 0.001}),
"Wrong result for {}",
err
);
@@ -69,5 +72,3 @@ fn mvmul_big_test() {
}
@@ -0,0 +1,5 @@
use pimcore::cpu::CPU;
pub fn empty_cpu(num_cores: usize) -> CPU<'static> {
CPU::new(num_cores, vec![Vec::new(); num_cores + 1])
}
@@ -1,51 +1,103 @@
use std::{fs, io::BufReader, path::Path};
use std::{
fs::{self, File},
io::BufReader,
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use pimcore::json_to_instruction::json_to_executor;
use pimcore::{
cpu::crossbar::Crossbar,
json_to_instruction::json_to_executor,
memory_manager::CoreMemory,
};
use serde_json::Value;
fn collect_json_from_subfolders<P: AsRef<Path>>(root: P) -> Result<Vec<(Value, Vec<Value>)>> {
fn collect_examples<P: AsRef<Path>>(root: P) -> Result<Vec<PathBuf>> {
let mut result = Vec::new();
for entry in fs::read_dir(root)? {
let entry = entry.context("Root not found")?;
let path = entry.path();
if path.is_dir() {
let mut cores = Vec::new();
let mut config: Option<Value> = None;
for sub_entry in fs::read_dir(&path)
.with_context(|| format!("File {} not readable", path.display()))?
{
let sub_entry =
sub_entry.with_context(|| format!("File {} not readable", path.display()))?;
let sub_path = sub_entry.path();
if sub_path.is_file()
&& sub_path.extension().and_then(|s| s.to_str()) == Some("json")
{
let file = fs::File::open(&sub_path)
.with_context(|| format!("Subpath {} not opened", sub_path.display()))?;
let reader = BufReader::new(file);
let val: Value = serde_json::from_reader(reader).with_context(|| format!(
"Serde reader fail for subpath {}",
sub_path.display()
))?;
if sub_path.file_name().unwrap() == "config.json" {
config = Some(val);
} else {
cores.push(val);
}
}
}
result.push((config.unwrap(), cores));
result.push(path);
}
}
Ok(result)
}
fn core_sort_key(path: &Path) -> i32 {
let stem = path.file_stem().unwrap().to_str().unwrap();
stem[5..].parse::<i32>().unwrap()
}
fn crossbar_sort_key(path: &Path) -> i32 {
let stem = path.file_stem().unwrap().to_str().unwrap();
stem[9..].parse::<i32>().unwrap()
}
fn load_crossbars(folder: &Path, config: &Value) -> Result<Vec<Vec<Crossbar>>> {
let xbar_size = config.get("xbar_size").unwrap().as_array().unwrap();
let rows = xbar_size[0].as_i64().unwrap() as usize;
let cols = xbar_size[1].as_i64().unwrap() as usize;
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
let mut owned_crossbars = Vec::with_capacity(core_cnt + 1);
owned_crossbars.push(Vec::new());
for core_idx in 0..core_cnt {
let core_folder = folder.join(format!("core_{core_idx}"));
let mut core_crossbars = Vec::new();
if core_folder.is_dir() {
let mut paths: Vec<_> = fs::read_dir(&core_folder)?
.map(|entry| entry.map(|entry| entry.path()))
.collect::<std::io::Result<Vec<_>>>()?;
paths.sort_by_cached_key(|path| crossbar_sort_key(path));
for path in paths {
if path.extension().and_then(|ext| ext.to_str()) != Some("bin") {
continue;
}
let bytes = fs::read(&path)
.with_context(|| format!("failed to read crossbar {}", path.display()))?;
let mut crossbar = Crossbar::new(cols * 4, rows, CoreMemory::new());
crossbar.execute_store(&bytes)?;
core_crossbars.push(crossbar);
}
}
owned_crossbars.push(core_crossbars);
}
Ok(owned_crossbars)
}
#[test]
fn json_folder_tester() {
let examples = collect_json_from_subfolders("data").unwrap();
for example in examples {
let (config, cores) = example;
json_to_executor::json_to_executor(config, cores.iter()).execute();
let examples = collect_examples("tests/data").unwrap();
for folder in examples {
let config_path = folder.join("config.json");
let config_file = File::open(&config_path).unwrap();
let config: Value = serde_json::from_reader(BufReader::new(config_file)).unwrap();
let core_cnt = config.get("core_cnt").unwrap().as_i64().unwrap() as usize;
let mut core_paths: Vec<_> = fs::read_dir(&folder)
.unwrap()
.map(|entry| entry.unwrap().path())
.filter(|path| path.extension().and_then(|ext| ext.to_str()) == Some("json"))
.filter(|path| path.file_name().unwrap() != "config.json")
.collect();
core_paths.sort_by_cached_key(|path| core_sort_key(path));
assert_eq!(core_paths.len(), core_cnt);
let mut core_readers: Vec<_> = core_paths
.into_iter()
.map(|path| BufReader::new(File::open(path).unwrap()))
.collect();
let owned_crossbars = load_crossbars(&folder, &config).unwrap();
let crossbars = owned_crossbars
.iter()
.map(|core_crossbars| core_crossbars.iter().collect())
.collect();
let mut executable = json_to_executor::json_to_executor(config, &mut core_readers, crossbars);
let memory = fs::read(folder.join("memory.bin")).unwrap();
executable.cpu_mut().host().execute_store(0, &memory).unwrap();
executable.execute();
}
}
@@ -1,11 +1,17 @@
mod common;
use pimcore::{Executable, cpu::CPU, instruction_set::{InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*}};
use pimcore::{
Executable,
instruction_set::{
InstructionType, InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*,
},
};
#[test]
#[should_panic(expected = "Function not found for the requested size") ]
fn wrong_size_place_holder() {
let cpu = CPU::new(0);
let cpu = common::empty_cpu(0);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
@@ -30,7 +36,7 @@ fn wrong_size_place_holder() {
fn place_holder(inst : InstructionType) {
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(0).fix_core_indx();
inst(&mut cpu, idata_build.build()).unwrap();
@@ -1,8 +1,10 @@
mod common;
use pimcore::{
Executable,
cpu::CPU,
cpu::crossbar::Crossbar,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
memory_manager::{MemoryStorable, type_traits::UpcastDestTraits},
memory_manager::{CoreMemory, MemoryStorable, type_traits::UpcastDestTraits},
};
/// VVADD Test
@@ -11,7 +13,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -115,7 +117,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -219,7 +221,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -323,7 +325,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
1.0.into(),
2.0.into(),
@@ -420,7 +422,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
9.0.into(),
2.0.into(),
@@ -524,7 +526,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
9.0.into(),
2.0.into(),
@@ -562,6 +564,7 @@ where
vavg,
idata_build
.set_rdr1r2(3, 1, 1)
.set_offset_select(1)
.set_imm_len(8 * size_of::<F>() as i32)
.build(),
);
@@ -617,7 +620,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
(-9.0).into(),
2.0.into(),
@@ -717,7 +720,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
0.1.into(),
0.2.into(),
@@ -819,7 +822,7 @@ where
F: From<f32> + std::fmt::Debug + PartialEq<F> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
let mut cpu = common::empty_cpu(0);
let buff: [F; _] = [
0.1.into(),
0.2.into(),
@@ -923,9 +926,6 @@ where
M: From<f32> + std::fmt::Debug + PartialEq<M> + MemoryStorable,
T: From<f32> + std::fmt::Debug + PartialEq<T> + MemoryStorable + UpcastDestTraits<T>,
{
let mut cpu = CPU::new(0);
cpu.reserve_crossbar(1, 4 * size_of::<M>(), 4);
let (memory, crossbars) = cpu.host().get_memory_crossbar();
let matrix: [M; _] = [
1.0.into(),
2.0.into(),
@@ -944,7 +944,10 @@ where
15.0.into(),
16.0.into(),
];
crossbars.get_mut(0).unwrap().execute_store( &matrix).unwrap();
let mut crossbar = Crossbar::new(4 * size_of::<M>(), 4, CoreMemory::new());
crossbar.execute_store(&matrix).unwrap();
let mut cpu = pimcore::cpu::CPU::new(0, vec![vec![&crossbar]]);
let (memory, _) = cpu.host().get_memory_crossbar();
let vector: [F; _] = [
1.0.into(),
2.0.into(),
@@ -1054,5 +1057,3 @@ fn mvmul_test() {
mvmul_test_generic::<f64,f64,f64>("mvmul<f64,f64,f64>",1);
}
@@ -1,12 +1,13 @@
mod common;
use pimcore::{
Executable, CoreInstructionsBuilder,
cpu::CPU,
instruction_set::{InstructionsBuilder, instruction_data::InstructionDataBuilder, isa::*},
};
#[test]
fn ld_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -41,7 +42,7 @@ fn ld_test() {
#[test]
fn st_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -76,7 +77,7 @@ fn st_test() {
#[test]
fn lldi_test() {
let cpu = CPU::new(1);
let cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let mut inst_builder = InstructionsBuilder::new();
let mut idata_build = InstructionDataBuilder::new();
@@ -106,7 +107,7 @@ fn lldi_test() {
#[test]
fn lmv_test() {
let mut cpu = CPU::new(1);
let mut cpu = common::empty_cpu(1);
let mut core_instruction_builder = CoreInstructionsBuilder::new(1);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -148,7 +149,7 @@ fn lmv_test() {
#[test]
fn simple_send_recv_test() {
let mut cpu = CPU::new(2);
let mut cpu = common::empty_cpu(2);
let mut core_instruction_builder = CoreInstructionsBuilder::new(2);
let buff: [f32; _] = [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
@@ -207,7 +208,7 @@ fn simple_send_recv_test() {
#[test]
fn multiple_send_recv_test() {
let mut cpu = CPU::new(4);
let mut cpu = common::empty_cpu(4);
let mut core_instruction_builder = CoreInstructionsBuilder::new(4);
let buff: [f32; _] = [
1.0, 1.0, 1.0, 1.0, 1.0
@@ -226,7 +227,7 @@ fn multiple_send_recv_test() {
];
cpu.core(4).execute_store(0, &buff).unwrap();
let send_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, inst_builder: &mut InstructionsBuilder, from : i32, to : i32| {
let send_inst = |inst_builder: &mut InstructionsBuilder, from: i32, to: i32| {
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(from).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
@@ -240,7 +241,7 @@ fn multiple_send_recv_test() {
);
};
let recv_inst = |cpu :&mut CPU, core_instruction_builder: &mut CoreInstructionsBuilder, mut inst_builder: &mut InstructionsBuilder, to : i32, from : i32| {
let recv_inst = |inst_builder: &mut InstructionsBuilder, to: i32, from: i32| {
let mut idata_build = InstructionDataBuilder::new();
idata_build.set_core_indx(to).fix_core_indx();
inst_builder.make_inst(sldi, idata_build.set_rdimm(1, from*size_of::<f32>() as i32).build());
@@ -257,26 +258,26 @@ fn multiple_send_recv_test() {
// 1 -> 3
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,1, 3);
send_inst(&mut inst_builder, 1, 3);
core_instruction_builder.set_core(1, inst_builder.build());
// 2 -> 3
// 2 <- 4
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 3);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,2, 4);
send_inst(&mut inst_builder, 2, 3);
recv_inst(&mut inst_builder, 2, 4);
core_instruction_builder.set_core(2, inst_builder.build());
// 3 <- 2
// 3 <- 4
// 3 <- 1
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 2);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 4);
recv_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,3, 1);
recv_inst(&mut inst_builder, 3, 2);
recv_inst(&mut inst_builder, 3, 4);
recv_inst(&mut inst_builder, 3, 1);
core_instruction_builder.set_core(3, inst_builder.build());
// 4 -> 2
// 4 -> 3
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 2);
send_inst(&mut cpu,&mut core_instruction_builder,&mut inst_builder,4, 3);
send_inst(&mut inst_builder, 4, 2);
send_inst(&mut inst_builder, 4, 3);
core_instruction_builder.set_core(4, inst_builder.build());
let mut executable = Executable::new(cpu, core_instruction_builder.build());
+56
View File
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
set(PIM_GENERATED_PATH_SHIM_TARGET "")
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
function(add_pim_generated_path_shim relative_path)
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
add_custom_command(
OUTPUT "${shim_file}"
DEPENDS "${real_file}"
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
VERBATIM
)
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
endfunction()
file(GLOB_RECURSE pim_generated_path_scan_sources
CONFIGURE_DEPENDS
"${PIM_SRC_ROOT}/*.cpp"
"${PIM_SRC_ROOT}/*.hpp"
)
set(pim_generated_path_shims)
foreach (source_file IN LISTS pim_generated_path_scan_sources)
file(READ "${source_file}" source_contents)
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
foreach (inc_match IN LISTS source_inc_matches)
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
list(APPEND pim_generated_path_shims "${relative_inc_path}")
endforeach ()
endforeach ()
list(REMOVE_DUPLICATES pim_generated_path_shims)
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
add_pim_generated_path_shim("${relative_inc_path}")
endforeach ()
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
endif ()
set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN})
if (PIM_GENERATED_PATH_SHIM_TARGET)
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
endif ()
endfunction()
add_subdirectory(Dialect)
@@ -68,5 +121,8 @@ add_pim_library(OMPIMAccel
OMSpatialToPim
OMPimCommon
OMPimBufferization
OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimVerification
MLIRTensorInferTypeOpInterfaceImpl
)
+17 -1
View File
@@ -1,5 +1,19 @@
add_pim_library(OMPimCommon
PimCommon.cpp
IR/AffineUtils.cpp
IR/AddressAnalysis.cpp
IR/BatchCoreUtils.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/LoopUtils.cpp
IR/ShapeUtils.cpp
IR/SubviewUtils.cpp
IR/WeightUtils.cpp
Support/CheckedArithmetic.cpp
Support/DebugDump.cpp
Support/Diagnostics.cpp
Support/FileSystemUtils.cpp
Support/ReportUtils.cpp
EXCLUDE_FROM_OM_LIBS
@@ -7,6 +21,8 @@ add_pim_library(OMPimCommon
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
onnx
SpatialOps
PimOps
+807
View File
@@ -0,0 +1,807 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include <limits>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<mlir::memref::GlobalOp>(getGlobalOp.getName());
}
namespace {
mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (mlir::isa<mlir::BlockArgument>(value))
return value;
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
if (auto result = mlir::dyn_cast<mlir::OpResult>(value))
if (mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
return mlir::failure();
return strides;
}
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) {
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
}
llvm_unreachable("unknown cmpi predicate");
}
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
return iter->second;
return mlir::failure();
}
case CompiledIndexExprNode::Kind::Add:
case CompiledIndexExprNode::Kind::Sub:
case CompiledIndexExprNode::Kind::Mul:
case CompiledIndexExprNode::Kind::DivUI:
case CompiledIndexExprNode::Kind::DivSI:
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::DivSI:
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
return mlir::failure();
return *lhs / *rhs;
case CompiledIndexExprNode::Kind::RemUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::RemSI:
if (*rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default: llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
if (failed(condition))
return mlir::failure();
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
}
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
if (!denseAttr || !globalType)
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(expr.node->operands.size());
for (const CompiledIndexExpr& operand : expr.node->operands) {
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
}
llvm_unreachable("unknown compiled index kind");
}
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
expr.globalOp = globalOp;
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
expr.operands.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto compiledIndex = compileIndexValueImpl(index);
if (failed(compiledIndex))
return mlir::failure();
expr.operands.push_back(*compiledIndex);
}
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = integerAttr.getInt();
return makeCompiledIndexExpr(std::move(expr));
}
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
auto lhs = compileIndexValueImpl(lhsValue);
auto rhs = compileIndexValueImpl(rhsValue);
if (failed(lhs) || failed(rhs))
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {*lhs, *rhs};
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
};
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return compileIndexValueImpl(indexCastOp.getIn());
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
if (failed(expr))
return mlir::failure();
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
exprNode->predicate = cmpOp.getPredicate();
return CompiledIndexExpr(exprNode);
}
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
auto lhs = compileIndexValueImpl(maxOp.getLhs());
auto rhs = compileIndexValueImpl(maxOp.getRhs());
if (failed(lhs) || failed(rhs))
return mlir::failure();
CompiledIndexExprNode cmpExpr;
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
cmpExpr.operands = {*lhs, *rhs};
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
return makeCompiledIndexExpr(std::move(selectExpr));
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = compileIndexValueImpl(selectOp.getCondition());
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
if (failed(condition) || failed(trueValue) || failed(falseValue))
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Select;
expr.operands = {*condition, *trueValue, *falseValue};
return makeCompiledIndexExpr(std::move(expr));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return compileConstantGlobalLoad(loadOp);
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs + *rhs;
}
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs - *rhs;
}
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return *lhs * *rhs;
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return mlir::failure();
return *lhs / *rhs;
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
}
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure();
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = mlir::dyn_cast<mlir::Attribute>(ofr)) {
auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!integerAttr)
return mlir::failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(mlir::cast<mlir::Value>(ofr), knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return mlir::failure();
offsets.push_back(*resolvedOffset);
}
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return mlir::failure();
sizes.push_back(*resolvedSize);
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return mlir::failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return mlir::failure();
}
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
int64_t constantByteOffset = 0;
CompiledIndexExpr byteOffsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return CompiledAddressExpr {value, byteOffsetExpr};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = tiedOperand->get();
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
bool hasOnlyStaticOffsets = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
hasOnlyStaticOffsets = false;
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
if (!attr)
return mlir::failure();
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
if (!attr)
return mlir::failure();
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
if (!isContiguousSubviewWithDynamicOffsets(
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
return mlir::failure();
}
if (hasOnlyStaticOffsets) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
constantByteOffset +=
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
* getElementTypeSizeInBytes(subviewType.getElementType());
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
else {
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
if (failed(compiledOffset))
return mlir::failure();
CompiledIndexExpr scaleExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
scaleExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Mul;
expr.operands = {*compiledOffset, scaleExpr};
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {offsetExpr, operandExpr};
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, offsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
constantByteOffset = 0;
}
value = subviewOp.getSource();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
if (constantByteOffset != 0) {
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
byteOffsetExpr = constantExpr;
else {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, byteOffsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
}
return CompiledAddressExpr {value, byteOffsetExpr};
}
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
return compileContiguousAddressExprImpl(value);
}
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset};
}
} // namespace onnx_mlir
+94
View File
@@ -0,0 +1,94 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include <memory>
#include <optional>
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
/// byte offset after peeling aliases, casts, and contiguous subviews.
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
/// Records compile-time facts used when interpreting address arithmetic and
/// loop-carried aliases inside PIM regions.
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
struct CompiledIndexExprNode;
struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
: node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
};
struct CompiledIndexExprNode {
enum class Kind {
Constant,
Symbol,
Add,
Sub,
Mul,
DivUI,
DivSI,
RemUI,
RemSI,
MinUI,
CmpI,
Select,
ConstantGlobalLoad
};
Kind kind = Kind::Constant;
int64_t constant = 0;
mlir::Value symbol;
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t, 4> globalStrides;
llvm::SmallVector<CompiledIndexExpr, 4> operands;
};
struct CompiledAddressExpr {
mlir::Value base;
CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const;
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge = {});
/// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
} // namespace onnx_mlir
+182
View File
@@ -0,0 +1,182 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs < 0)
--quotient;
return quotient;
}
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs > 0)
++quotient;
return quotient;
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1)
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt());
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
}
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
if (multiplier == 0)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
if (multiplier == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
}
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.mod divisor");
if (divisor == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor);
}
Value affineFloorDivConst(
RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
if (divisor == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
return constant.getValue();
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
unsigned position = dim.getPosition();
if (position >= dims.size())
return failure();
return dims[position];
}
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
unsigned position = symbol.getPosition();
if (position >= symbols.size())
return failure();
return symbols[position];
}
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binary)
return failure();
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
if (failed(lhs) || failed(rhs))
return failure();
switch (binary.getKind()) {
case AffineExprKind::Add: return *lhs + *rhs;
case AffineExprKind::Mul: return *lhs * *rhs;
case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
case AffineExprKind::Mod: {
FailureOr<int64_t> div = floorDivSigned(*lhs, *rhs);
if (failed(div))
return failure();
return *lhs - *div * *rhs;
}
default: return failure();
}
}
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
return failure();
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
return evaluateAffineExpr(map.getResult(0), dims, symbols);
}
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
SmallVector<int64_t, 4> operands;
operands.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = resolver(operand);
if (failed(folded))
return failure();
operands.push_back(*folded);
}
return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
}
bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; }
bool isDimAndConstantAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
} // namespace onnx_mlir
+55
View File
@@ -0,0 +1,55 @@
#pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
namespace onnx_mlir {
using IndexValueResolver = llvm::function_ref<llvm::FailureOr<int64_t>(mlir::Value)>;
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineMap map,
mlir::ValueRange operands,
mlir::Operation* constantAnchor);
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange dims,
mlir::Operation* constantAnchor);
mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t multiplier,
mlir::Operation* constantAnchor);
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
llvm::FailureOr<int64_t>
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
llvm::FailureOr<int64_t> evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef<int64_t> operands);
llvm::FailureOr<int64_t> evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver);
bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map);
bool isDimAndConstantAffineExpr(mlir::AffineExpr expr);
} // namespace onnx_mlir
+32
View File
@@ -0,0 +1,32 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
llvm::SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
}
} // namespace onnx_mlir
+18
View File
@@ -0,0 +1,18 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
} // namespace onnx_mlir
+745
View File
@@ -0,0 +1,745 @@
#ifndef ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#define ONNX_MLIR_PIM_COMPACT_ASM_UTILS_HPP
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
namespace compact_asm {
using namespace mlir;
enum class ListDelimiter {
Square,
Paren
};
inline ParseResult parseOpenDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseLSquare();
return parser.parseLParen();
}
inline ParseResult parseOptionalCloseDelimiter(OpAsmParser& parser, ListDelimiter delimiter) {
if (delimiter == ListDelimiter::Square)
return parser.parseOptionalRSquare();
return parser.parseOptionalRParen();
}
template <typename StreamT>
inline void printOpenDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "[" : "(");
}
template <typename StreamT>
inline void printCloseDelimiter(StreamT& stream, ListDelimiter delimiter) {
stream << (delimiter == ListDelimiter::Square ? "]" : ")");
}
template <typename EntryT, typename ParseEntryFn>
inline ParseResult parseCompressedRepeatedList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<EntryT>& entries,
ParseEntryFn parseEntry) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
EntryT entry;
if (parseEntry(entry))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
entries.push_back(entry);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerEntries(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<IntT> subgroup;
if (parseCompressedIntegerEntries(parser, ListDelimiter::Paren, subgroup))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(values, subgroup);
}
else {
int64_t first = 0;
if (parser.parseInteger(first))
return failure();
if (succeeded(parser.parseOptionalKeyword("to"))) {
int64_t last = 0;
if (parser.parseInteger(last) || last < first)
return parser.emitError(parser.getCurrentLocation(), "invalid ascending range");
int64_t step = 1;
if (succeeded(parser.parseOptionalKeyword("by"))) {
if (parser.parseInteger(step) || step <= 0)
return parser.emitError(parser.getCurrentLocation(), "step after 'by' must be positive");
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
if ((last - first) % step != 0) {
return parser.emitError(parser.getCurrentLocation(),
"range end must be reachable from start using the given step");
}
for (int64_t value = first; value <= last; value += step)
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(value));
}
else {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
values.push_back(static_cast<IntT>(first));
}
}
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
template <typename IntT>
inline ParseResult
parseCompressedIntegerSequence(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<IntT>& values) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
return parseCompressedIntegerEntries(parser, delimiter, values);
}
template <typename RangeT, typename PrintEntryFn>
inline void printCompressedEqualRuns(OpAsmPrinter& printer, RangeT entries, PrintEntryFn printEntry) {
for (size_t index = 0; index < entries.size();) {
size_t runEnd = index + 1;
while (runEnd < entries.size() && entries[runEnd] == entries[index])
++runEnd;
if (index != 0)
printer << ", ";
printEntry(entries[index]);
size_t runLength = runEnd - index;
if (runLength > 1)
printer << " x" << runLength;
index = runEnd;
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerEntries(StreamT& stream, ArrayRef<IntT> values) {
struct FlatCompression {
enum class Kind {
Single,
EqualRun,
Progression
};
Kind kind = Kind::Single;
size_t covered = 1;
size_t repeatCount = 1;
size_t progressionValueCount = 1;
int64_t step = 1;
IntT firstValue {};
IntT lastValue {};
};
auto computeFlatCompression = [&](size_t start) {
FlatCompression compression;
compression.firstValue = values[start];
compression.lastValue = values[start];
auto findEqualRunEnd = [&](size_t runStart) {
size_t runEnd = runStart + 1;
while (runEnd < values.size() && values[runEnd] == values[runStart])
++runEnd;
return runEnd;
};
size_t firstRunEnd = findEqualRunEnd(start);
compression.repeatCount = firstRunEnd - start;
size_t progressionEnd = firstRunEnd;
int64_t step = 0;
IntT lastValue = values[start];
if (firstRunEnd < values.size()) {
size_t secondRunEnd = findEqualRunEnd(firstRunEnd);
step = static_cast<int64_t>(values[firstRunEnd]) - static_cast<int64_t>(values[start]);
if (step > 0 && secondRunEnd - firstRunEnd == compression.repeatCount) {
progressionEnd = secondRunEnd;
lastValue = values[firstRunEnd];
size_t currentRunStart = secondRunEnd;
while (currentRunStart < values.size()) {
size_t currentRunEnd = findEqualRunEnd(currentRunStart);
if (currentRunEnd - currentRunStart != compression.repeatCount)
break;
if (static_cast<int64_t>(values[currentRunStart]) != static_cast<int64_t>(lastValue) + step)
break;
lastValue = values[currentRunStart];
progressionEnd = currentRunEnd;
currentRunStart = currentRunEnd;
}
}
else {
step = 0;
}
}
compression.covered = 1;
if (progressionEnd > firstRunEnd) {
size_t progressionValueCount = (progressionEnd - start) / compression.repeatCount;
if (progressionValueCount >= 3) {
compression.kind = FlatCompression::Kind::Progression;
compression.covered = progressionEnd - start;
compression.progressionValueCount = progressionValueCount;
compression.step = step;
compression.lastValue = lastValue;
return compression;
}
}
if (compression.repeatCount > 1) {
compression.kind = FlatCompression::Kind::EqualRun;
compression.covered = compression.repeatCount;
return compression;
}
return compression;
};
auto findRepeatedSublist = [&](size_t start) {
size_t bestLength = 0;
size_t bestRepeatCount = 1;
size_t remaining = values.size() - start;
for (size_t length = 2; length * 2 <= remaining; ++length) {
size_t repeatCount = 1;
ArrayRef<IntT> candidate = values.slice(start, length);
while (start + (repeatCount + 1) * length <= values.size()
&& llvm::equal(candidate, values.slice(start + repeatCount * length, length))) {
++repeatCount;
}
if (repeatCount <= 1)
continue;
size_t covered = length * repeatCount;
size_t bestCovered = bestLength * bestRepeatCount;
if (covered > bestCovered || (covered == bestCovered && length < bestLength)) {
bestLength = length;
bestRepeatCount = repeatCount;
}
}
return std::pair(bestLength, bestRepeatCount);
};
for (size_t index = 0; index < values.size();) {
if (index != 0)
stream << ", ";
FlatCompression flat = computeFlatCompression(index);
auto [sublistLength, sublistRepeatCount] = findRepeatedSublist(index);
size_t repeatedSublistCoverage = sublistLength * sublistRepeatCount;
if (sublistRepeatCount > 1 && sublistLength > 1 && repeatedSublistCoverage > flat.covered) {
printOpenDelimiter(stream, ListDelimiter::Paren);
printCompressedIntegerEntries(stream, values.slice(index, sublistLength));
printCloseDelimiter(stream, ListDelimiter::Paren);
stream << " x" << sublistRepeatCount;
index += repeatedSublistCoverage;
continue;
}
switch (flat.kind) {
case FlatCompression::Kind::Progression:
stream << flat.firstValue << " to " << flat.lastValue;
if (flat.step != 1)
stream << " by " << flat.step;
if (flat.repeatCount > 1)
stream << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::EqualRun:
stream << flat.firstValue << " x" << flat.repeatCount;
index += flat.covered;
break;
case FlatCompression::Kind::Single:
stream << flat.firstValue;
index += flat.covered;
break;
}
}
}
template <typename StreamT, typename IntT>
inline void printCompressedIntegerSequence(StreamT& stream, ArrayRef<IntT> values, ListDelimiter delimiter) {
printOpenDelimiter(stream, delimiter);
printCompressedIntegerEntries(stream, values);
printCloseDelimiter(stream, delimiter);
}
template <typename IntT>
inline ParseResult parseCompressedIntegerList(OpAsmParser& parser, SmallVectorImpl<IntT>& values) {
return parseCompressedIntegerSequence(parser, ListDelimiter::Square, values);
}
template <typename IntT>
inline void printCompressedIntegerList(OpAsmPrinter& printer, ArrayRef<IntT> values) {
printCompressedIntegerSequence(printer, values, ListDelimiter::Square);
}
inline void printCompressedValueSequence(OpAsmPrinter& printer, ValueRange values) {
for (size_t index = 0; index < values.size();) {
size_t equalRunEnd = index + 1;
while (equalRunEnd < values.size() && values[equalRunEnd] == values[index])
++equalRunEnd;
if (index != 0)
printer << ", ";
if (equalRunEnd - index > 1) {
printer.printOperand(values[index]);
printer << " x" << (equalRunEnd - index);
index = equalRunEnd;
continue;
}
size_t rangeEnd = index + 1;
if (auto firstResult = dyn_cast<OpResult>(values[index])) {
while (rangeEnd < values.size()) {
auto nextResult = dyn_cast<OpResult>(values[rangeEnd]);
if (!nextResult || nextResult.getOwner() != firstResult.getOwner()
|| nextResult.getResultNumber() != firstResult.getResultNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
else if (auto firstArg = dyn_cast<BlockArgument>(values[index])) {
while (rangeEnd < values.size()) {
auto nextArg = dyn_cast<BlockArgument>(values[rangeEnd]);
if (!nextArg || nextArg.getOwner() != firstArg.getOwner()
|| nextArg.getArgNumber() != firstArg.getArgNumber() + (rangeEnd - index)) {
break;
}
++rangeEnd;
}
}
printer.printOperand(values[index]);
if (rangeEnd - index >= 3) {
printer << " to ";
printer.printOperand(values[rangeEnd - 1]);
}
else if (rangeEnd - index == 2) {
printer << ", ";
printer.printOperand(values[index + 1]);
}
index = rangeEnd;
}
}
inline void printCompressedTypeSequence(OpAsmPrinter& printer, TypeRange types) {
printCompressedEqualRuns(printer, types, [&](Type type) { printer.printType(type); });
}
inline void printCompressedValueList(OpAsmPrinter& printer, ValueRange values, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedValueSequence(printer, values);
printCloseDelimiter(printer, delimiter);
}
inline void printCompressedTypeList(OpAsmPrinter& printer, TypeRange types, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printCompressedTypeSequence(printer, types);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedTypeSequence(OpAsmParser& parser, SmallVectorImpl<Type>& types, bool allowEmpty) {
Type firstType;
OptionalParseResult firstTypeResult = parser.parseOptionalType(firstType);
if (!firstTypeResult.has_value()) {
if (allowEmpty)
return success();
return parser.emitError(parser.getCurrentLocation(), "expected type");
}
if (failed(*firstTypeResult))
return failure();
auto appendType = [&](Type type) -> ParseResult {
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
types.push_back(type);
return success();
};
if (appendType(firstType))
return failure();
while (succeeded(parser.parseOptionalComma())) {
Type nextType;
if (parser.parseType(nextType) || appendType(nextType))
return failure();
}
return success();
}
inline ParseResult parseCompressedOperandEntryWithFirst(OpAsmParser& parser,
OpAsmParser::UnresolvedOperand firstOperand,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::UnresolvedOperand lastOperand;
if (parser.parseOperand(lastOperand))
return failure();
if (firstOperand.name != lastOperand.name || firstOperand.number > lastOperand.number)
return parser.emitError(parser.getCurrentLocation(), "invalid operand range");
for (unsigned number = firstOperand.number; number <= lastOperand.number; ++number)
operands.push_back({firstOperand.location, firstOperand.name, number});
return success();
}
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t index = 0; index < repeatCount; ++index)
operands.push_back(firstOperand);
return success();
}
inline ParseResult parseOneCompressedOperandEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
OpAsmParser::UnresolvedOperand firstOperand;
if (parser.parseOperand(firstOperand))
return failure();
return parseCompressedOperandEntryWithFirst(parser, firstOperand, operands);
}
inline ParseResult parseCompressedOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult parseCompressedOperandSequence(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
return success();
}
inline ParseResult parseCompressedTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parseCompressedTypeSequence(parser, types, /*allowEmpty=*/false))
return failure();
return parseOptionalCloseDelimiter(parser, delimiter);
}
inline bool hasRepeatedTuple(ValueRange values, size_t tupleSize) {
if (tupleSize == 0 || values.empty() || values.size() % tupleSize != 0)
return false;
SmallVector<Value> valueVec(values.begin(), values.end());
ArrayRef<Value> tuple(valueVec.data(), tupleSize);
for (size_t index = tupleSize; index < values.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Value>(valueVec).slice(index, tupleSize)))
return false;
return true;
}
inline bool hasRepeatedTuple(TypeRange types, size_t tupleSize) {
if (tupleSize == 0 || types.empty() || types.size() % tupleSize != 0)
return false;
SmallVector<Type> typeVec(types.begin(), types.end());
ArrayRef<Type> tuple(typeVec.data(), tupleSize);
for (size_t index = tupleSize; index < types.size(); index += tupleSize)
if (!llvm::equal(tuple, ArrayRef<Type>(typeVec).slice(index, tupleSize)))
return false;
return true;
}
inline void printValueTupleRun(OpAsmPrinter& printer, ValueRange values, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printOperand(values[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (values.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline void printTypeTupleRun(OpAsmPrinter& printer, TypeRange types, size_t tupleSize, ListDelimiter delimiter) {
printOpenDelimiter(printer, delimiter);
printOpenDelimiter(printer, ListDelimiter::Paren);
for (size_t index = 0; index < tupleSize; ++index) {
if (index != 0)
printer << ", ";
printer.printType(types[index]);
}
printCloseDelimiter(printer, ListDelimiter::Paren);
printer << " x" << (types.size() / tupleSize);
printCloseDelimiter(printer, delimiter);
}
inline ParseResult parseCompressedOrTupleOperandList(OpAsmParser& parser,
ListDelimiter delimiter,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<OpAsmParser::UnresolvedOperand> tupleOperands;
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleOperands.clear();
if (parseCompressedOperandSequence(parser, tupleOperands) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(operands, tupleOperands);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
if (parseOneCompressedOperandEntry(parser, operands))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline ParseResult
parseCompressedOrTupleTypeList(OpAsmParser& parser, ListDelimiter delimiter, SmallVectorImpl<Type>& types) {
if (parseOpenDelimiter(parser, delimiter))
return failure();
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (succeeded(parser.parseOptionalLParen())) {
SmallVector<Type> tupleTypes;
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
while (succeeded(parser.parseOptionalComma())) {
if (parser.parseLParen())
return failure();
tupleTypes.clear();
if (parseCompressedTypeSequence(parser, tupleTypes, /*allowEmpty=*/false) || parser.parseRParen())
return failure();
repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
llvm::append_range(types, tupleTypes);
}
return parseOptionalCloseDelimiter(parser, delimiter);
}
while (true) {
Type type;
if (parser.parseType(type))
return failure();
int64_t repeatCount = 1;
if (succeeded(parser.parseOptionalKeyword("x"))) {
if (parser.parseInteger(repeatCount) || repeatCount <= 0)
return parser.emitError(parser.getCurrentLocation(), "repeat count after 'x' must be positive");
}
for (int64_t repeat = 0; repeat < repeatCount; ++repeat)
types.push_back(type);
if (succeeded(parseOptionalCloseDelimiter(parser, delimiter)))
return success();
if (parser.parseComma())
return failure();
}
}
inline void printArgumentBindings(OpAsmPrinter& printer, Block& block, ValueRange operands) {
if (block.getNumArguments() == 0) {
printer << "() = ()";
return;
}
if (block.getNumArguments() == 1) {
printer.printOperand(block.getArgument(0));
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
return;
}
printCompressedValueList(printer, ValueRange(block.getArguments()), ListDelimiter::Paren);
printer << " = ";
printCompressedValueList(printer, operands, ListDelimiter::Paren);
}
inline ParseResult parseCompressedArgumentEntryWithFirst(OpAsmParser& parser,
OpAsmParser::Argument firstArgument,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
if (succeeded(parser.parseOptionalKeyword("to"))) {
OpAsmParser::Argument lastArgument;
if (parser.parseArgument(lastArgument))
return failure();
if (firstArgument.ssaName.name != lastArgument.ssaName.name
|| firstArgument.ssaName.number > lastArgument.ssaName.number) {
return parser.emitError(parser.getCurrentLocation(), "invalid argument range");
}
for (unsigned number = firstArgument.ssaName.number; number <= lastArgument.ssaName.number; ++number) {
OpAsmParser::Argument argument;
argument.ssaName = {firstArgument.ssaName.location, firstArgument.ssaName.name, number};
arguments.push_back(argument);
}
return success();
}
arguments.push_back(firstArgument);
return success();
}
inline ParseResult parseOneCompressedArgumentEntry(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments) {
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument))
return failure();
return parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments);
}
inline void applyArgumentTypes(ArrayRef<Type> inputTypes, SmallVectorImpl<OpAsmParser::Argument>& arguments) {
for (auto [argument, inputType] : llvm::zip_equal(arguments, inputTypes))
argument.type = inputType;
}
inline ParseResult parseArgumentBindings(OpAsmParser& parser,
SmallVectorImpl<OpAsmParser::Argument>& arguments,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
if (succeeded(parser.parseOptionalLParen())) {
if (succeeded(parser.parseOptionalRParen())) {
if (parser.parseEqual() || parseCompressedOperandList(parser, ListDelimiter::Paren, operands))
return failure();
return success();
}
OpAsmParser::Argument firstArgument;
if (parser.parseArgument(firstArgument) || parseCompressedArgumentEntryWithFirst(parser, firstArgument, arguments))
return failure();
while (succeeded(parser.parseOptionalComma()))
if (parseOneCompressedArgumentEntry(parser, arguments))
return failure();
if (parser.parseRParen() || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
return success();
}
OpAsmParser::Argument argument;
if (parser.parseArgument(argument) || parser.parseEqual()
|| parseCompressedOperandList(parser, ListDelimiter::Paren, operands)) {
return failure();
}
arguments.push_back(argument);
return success();
}
} // namespace compact_asm
} // namespace onnx_mlir
#endif
+157
View File
@@ -0,0 +1,157 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
static std::optional<int64_t> getIndexConstantValue(arith::ConstantOp constantOp) {
if (!constantOp.getType().isIndex())
return std::nullopt;
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (!intAttr || !intAttr.getType().isIndex())
return std::nullopt;
return intAttr.getInt();
}
Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
return &funcOp.getBody().front();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
return moduleOp.getBody();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(hostBlock);
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
}
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
void hoistAndUniquifyIndexConstants(func::FuncOp funcOp, RewriterBase& rewriter) {
if (funcOp.getBody().empty())
return;
Block& entryBlock = funcOp.getBody().front();
DenseMap<int64_t, Value> canonicalByValue;
SmallVector<arith::ConstantOp> constants;
funcOp.walk([&](arith::ConstantOp constantOp) {
if (!getIndexConstantValue(constantOp))
return;
constants.push_back(constantOp);
});
for (arith::ConstantOp constantOp : constants) {
auto value = getIndexConstantValue(constantOp);
if (!value || constantOp->getBlock() != &entryBlock)
continue;
canonicalByValue.try_emplace(*value, constantOp.getResult());
}
for (arith::ConstantOp constantOp : constants) {
auto value = getIndexConstantValue(constantOp);
if (!value)
continue;
Value canonical = canonicalByValue.lookup(*value);
if (!canonical) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&entryBlock);
Builder builder(funcOp.getContext());
canonical =
arith::ConstantOp::create(rewriter, constantOp.getLoc(), builder.getIndexType(), builder.getIndexAttr(*value));
canonicalByValue[*value] = canonical;
}
if (constantOp.getResult() == canonical)
continue;
constantOp.getResult().replaceAllUsesWith(canonical);
}
for (arith::ConstantOp constantOp : llvm::reverse(constants)) {
auto value = getIndexConstantValue(constantOp);
if (!value)
continue;
if (constantOp.getResult() == canonicalByValue.lookup(*value))
continue;
if (constantOp.use_empty())
rewriter.eraseOp(constantOp);
}
}
std::optional<int64_t> matchConstantIndexValue(Value value) {
if (!value || !value.getType().isIndex())
return std::nullopt;
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
return std::nullopt;
}
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
if (auto operand = dyn_cast<Value>(value))
return matchConstantIndexValue(operand);
return std::nullopt;
}
} // namespace onnx_mlir
+33
View File
@@ -0,0 +1,33 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
#include <optional>
namespace onnx_mlir {
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
mlir::Value
getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Value
getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
void hoistAndUniquifyIndexConstants(mlir::func::FuncOp funcOp, mlir::RewriterBase& rewriter);
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
} // namespace onnx_mlir
+137
View File
@@ -0,0 +1,137 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
if (mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::DivSIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::RemSIOp,
mlir::arith::IndexCastOp,
mlir::arith::CmpIOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op))
return true;
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
return selectOp.getType().isIntOrIndex();
return false;
}
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
llvm::SmallVector<mlir::Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
hasFailure = true;
continue;
}
if (*step <= 0) {
forOp.emitOpError("requires positive scf.for step for PIM verification");
hasFailure = true;
continue;
}
llvm::SmallVector<int64_t, 2> samples;
if (*lowerBound < *upperBound) {
samples.push_back(*lowerBound);
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
if (last != *lowerBound)
samples.push_back(last);
}
for (int64_t inductionValue : samples) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
hasFailure = true;
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
+32
View File
@@ -0,0 +1,32 @@
#pragma once
#include "mlir/IR/Block.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
namespace onnx_mlir {
/// Returns true for ops in a `pim.core` body that only participate in static
/// address or index computation and therefore do not emit PIM instructions.
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks a `pim.core` body, statically unrolling nested `scf.for` loops when
/// their bounds are known and invoking `callback` only on instruction-emitting
/// operations.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
/// Walks a `pim.core`-like body structurally for verification without
/// enumerating full loop trip counts. Loop bounds must still be statically
/// evaluable so address resolution remains well-defined.
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir
+45
View File
@@ -0,0 +1,45 @@
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp) {
if (!moduleOp)
return mlir::failure();
llvm::SmallVector<mlir::ONNXEntryPointOp> entryPoints(moduleOp.getOps<mlir::ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return mlir::failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<mlir::SymbolRefAttr>(mlir::ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return mlir::failure();
}
auto entryFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return mlir::failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<mlir::func::FuncOp>("main_graph"))
return mainGraphFunc;
llvm::SmallVector<mlir::func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<mlir::func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return mlir::failure();
}
} // namespace onnx_mlir
+13
View File
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
/// Resolves the function the PIM pipeline should treat as its entry point.
/// Prefers ONNX entry-point metadata, then `main_graph`, then the only
/// non-external function if the module is otherwise unambiguous.
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
+96
View File
@@ -0,0 +1,96 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/MathExtras.h"
#include <optional>
#include "ConstantUtils.hpp"
#include "LoopUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static std::optional<int64_t> getStaticTripCount(Value lowerBound, Value upperBound, Value step) {
auto lower = matchConstantIndexValue(lowerBound);
auto upper = matchConstantIndexValue(upperBound);
auto stepValue = matchConstantIndexValue(step);
if (!lower || !upper || !stepValue)
return std::nullopt;
if (*stepValue <= 0)
return std::nullopt;
if (*upper <= *lower)
return int64_t {0};
return llvm::divideCeil(*upper - *lower, *stepValue);
}
} // namespace
static LogicalResult validateNormalizedLoopYields(Location loc, ValueRange initArgs, ArrayRef<Value> yieldedValues) {
if (yieldedValues.size() == initArgs.size())
return success();
emitError(loc) << "normalized loop body yielded " << yieldedValues.size() << " values for " << initArgs.size()
<< " iter args";
return failure();
}
FailureOr<NormalizedLoopResult> buildNormalizedScfFor(OpBuilder& builder,
Location loc,
Value lowerBound,
Value upperBound,
Value step,
ValueRange initArgs,
NormalizedLoopBodyBuilder bodyBuilder) {
NormalizedLoopResult result;
if (auto stepValue = matchConstantIndexValue(step); stepValue && *stepValue <= 0) {
emitError(loc) << "normalized scf.for requires a positive step, got " << *stepValue;
return failure();
}
if (auto tripCount = getStaticTripCount(lowerBound, upperBound, step)) {
if (*tripCount == 0) {
llvm::append_range(result.results, initArgs);
return result;
}
if (*tripCount == 1) {
result.inductionVar = lowerBound;
if (failed(bodyBuilder(builder, loc, lowerBound, initArgs, result.results)))
return failure();
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results)))
return failure();
return result;
}
}
result.loop = scf::ForOp::create(builder, loc, lowerBound, upperBound, step, initArgs);
result.inductionVar = result.loop.getInductionVar();
{
OpBuilder::InsertionGuard guard(builder);
Block* body = result.loop.getBody();
if (!body->empty())
if (auto yieldOp = dyn_cast<scf::YieldOp>(body->back()))
yieldOp->erase();
builder.setInsertionPointToEnd(body);
ValueRange iterArgs = result.loop.getRegionIterArgs();
if (failed(bodyBuilder(builder, loc, result.inductionVar, iterArgs, result.results))) {
result.loop.erase();
return failure();
}
if (failed(validateNormalizedLoopYields(loc, initArgs, result.results))) {
result.loop.erase();
return failure();
}
scf::YieldOp::create(builder, loc, result.results);
}
builder.setInsertionPointAfter(result.loop);
result.results.assign(result.loop.getResults().begin(), result.loop.getResults().end());
return result;
}
} // namespace onnx_mlir
+30
View File
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
namespace onnx_mlir {
struct NormalizedLoopResult {
mlir::Value inductionVar;
llvm::SmallVector<mlir::Value, 4> results;
mlir::scf::ForOp loop;
bool wasInlined() const { return !loop; }
};
using NormalizedLoopBodyBuilder = llvm::function_ref<mlir::LogicalResult(
mlir::OpBuilder&, mlir::Location, mlir::Value, mlir::ValueRange, llvm::SmallVectorImpl<mlir::Value>&)>;
mlir::FailureOr<NormalizedLoopResult> buildNormalizedScfFor(mlir::OpBuilder& builder,
mlir::Location loc,
mlir::Value lowerBound,
mlir::Value upperBound,
mlir::Value step,
mlir::ValueRange initArgs,
NormalizedLoopBodyBuilder bodyBuilder);
} // namespace onnx_mlir
+166
View File
@@ -0,0 +1,166 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape) {
llvm::SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides) {
llvm::SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides) {
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|| sourceShape.size() != staticStrides.size()) {
return false;
}
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
return false;
auto reversedTriples =
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
auto [_sourceDim, offset, _size] = triple;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
return true;
});
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
if (size > sourceDim - staticOffset)
return false;
}
++firstNonZeroOrDynamicOffset;
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
if (std::get<2>(*it) != 1)
return false;
}
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
auto [sourceDim, size] = pair;
return size != sourceDim;
});
if (firstDifferentSize != reversedSizes.end()) {
++firstDifferentSize;
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
if (std::get<1>(*it) != 1)
return false;
}
return true;
}
} // namespace onnx_mlir
+39
View File
@@ -0,0 +1,39 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides);
} // namespace onnx_mlir
+106
View File
@@ -0,0 +1,106 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir {
Value stripMemRefCasts(Value value) {
while (auto castOp = value.getDefiningOp<memref::CastOp>())
value = castOp.getSource();
return value;
}
Value stripMemRefViewOps(Value value) {
while (true) {
if (auto castOp = value.getDefiningOp<memref::CastOp>()) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = value.getDefiningOp<memref::CollapseShapeOp>()) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = value.getDefiningOp<memref::ExpandShapeOp>()) {
value = expandOp.getSrc();
continue;
}
return value;
}
}
Value stripMemRefAddressingOps(Value value) {
while (true) {
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
value = subviewOp.getSource();
continue;
}
Value strippedValue = stripMemRefViewOps(value);
if (strippedValue == value)
return value;
value = strippedValue;
}
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticStrides(), [](int64_t value) { return !ShapedType::isDynamic(value); });
}
FailureOr<StaticSubviewInfo> getStaticSubviewInfo(Value value) {
value = stripMemRefViewOps(value);
auto subviewOp = value.getDefiningOp<memref::SubViewOp>();
if (!subviewOp)
return failure();
auto source = stripMemRefCasts(subviewOp.getSource());
auto sourceType = dyn_cast<MemRefType>(source.getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
StaticSubviewInfo info;
info.source = source;
info.sourceShape.assign(sourceType.getShape().begin(), sourceType.getShape().end());
SmallVector<OpFoldResult> mixedOffsets = subviewOp.getMixedOffsets();
info.offsets.assign(mixedOffsets.begin(), mixedOffsets.end());
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto staticSize = getConstantIntValue(size);
if (!staticSize)
return failure();
info.sizes.push_back(*staticSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto staticStride = getConstantIntValue(stride);
if (!staticStride)
return failure();
info.strides.push_back(*staticStride);
}
return info;
}
FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info) {
SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(info.offsets.size());
for (OpFoldResult offset : info.offsets) {
auto staticOffset = getConstantIntValue(offset);
if (!staticOffset)
return failure();
staticOffsets.push_back(*staticOffset);
}
return staticOffsets;
}
bool isMemRefBaseAddressableValue(Value value) {
value = stripMemRefAddressingOps(value);
if (isa<BlockArgument>(value))
return true;
Operation* defOp = value.getDefiningOp();
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
}
} // namespace onnx_mlir
+34
View File
@@ -0,0 +1,34 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/LogicalResult.h"
namespace onnx_mlir {
struct StaticSubviewInfo {
mlir::Value source;
llvm::SmallVector<int64_t> sourceShape;
llvm::SmallVector<mlir::OpFoldResult> offsets;
llvm::SmallVector<int64_t> sizes;
llvm::SmallVector<int64_t> strides;
};
mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::Value stripMemRefAddressingOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
bool isMemRefBaseAddressableValue(mlir::Value value);
} // namespace onnx_mlir
+316
View File
@@ -0,0 +1,316 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
bool hasWeightAlways(mlir::Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(mlir::Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, mlir::UnitAttr::get(op->getContext()));
}
namespace {
CompiledIndexExpr makeConstantExpr(int64_t constant) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constant;
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
}
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {std::move(lhs), std::move(rhs)};
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
}
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
}
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefTypeStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::is_contained(strides, mlir::ShapedType::kDynamic))
return mlir::failure();
return strides;
}
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeight() == *weightArg;
});
return found;
}
template <typename VMMOpTy, typename ParentOpTy>
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg || *weightArg != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
mlir::Operation* user = use.getOwner();
unsigned operandIndex = use.getOperandNumber();
auto computeOp = mlir::dyn_cast<spatial::SpatCompute>(user);
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
llvm::SmallPtrSet<mlir::Value, 8> visited;
auto walkUses = [&](mlir::Value currentValue, auto& self) -> bool {
if (!visited.insert(currentValue).second)
return true;
if (currentValue.use_empty())
return false;
return llvm::all_of(currentValue.getUses(), [&](mlir::OpOperand& use) {
if (isSpatialMvmVmmWeightUse(use))
return true;
mlir::Operation* user = use.getOwner();
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(user))
return extractSliceOp.getSource() == currentValue && self(extractSliceOp.getResult(), self);
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(user))
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
return false;
});
};
return walkUses(value, walkUses);
}
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback) {
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
callback(coreOp->getOpOperand(*weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
callback(coreBatchOp->getOpOperand(*weightIndex));
});
});
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
weight = stripMemRefAddressingOps(weight);
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == weight)
return weightIndex;
return std::nullopt;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
return weightIndex;
return std::nullopt;
}
return std::nullopt;
}
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
llvm::SmallVector<mlir::Operation*> viewOps;
mlir::Value current = weight;
while (true) {
if (mlir::Value directAlias = knowledge.aliases.lookup(current); directAlias && directAlias != current) {
current = directAlias;
continue;
}
if (auto defOp = current.getDefiningOp()) {
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return mlir::failure();
ResolvedWeightView view;
view.globalOp = globalOp;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
CompiledIndexExpr offsetValue = makeConstantExpr(0);
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!intAttr)
return mlir::failure();
offsetValue = makeConstantExpr(intAttr.getInt());
}
else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
auto compiledOffset = compileIndexExpr(value);
if (failed(compiledOffset))
return mlir::failure();
offsetValue = *compiledOffset;
}
else {
return mlir::failure();
}
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
}
auto resultType = mlir::cast<mlir::MemRefType>(subview.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(viewOp)) {
auto resultType = mlir::cast<mlir::MemRefType>(castOp.getResult().getType());
auto resultStrides = getStaticMemRefTypeStrides(resultType);
if (failed(resultStrides))
return mlir::failure();
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = std::move(*resultStrides);
continue;
}
return mlir::failure();
}
auto resolvedOffset = offsetExpr.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
view.offset = *resolvedOffset;
return view;
}
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp)) {
viewOps.push_back(defOp);
current = subview.getSource();
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = collapse.getSrc();
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(defOp)) {
viewOps.push_back(defOp);
current = expand.getSrc();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
viewOps.push_back(defOp);
current = castOp.getSource();
continue;
}
return mlir::failure();
}
if (mlir::Value loopAlias = resolveLoopCarriedAlias(current, knowledge); loopAlias && loopAlias != current) {
current = loopAlias;
continue;
}
auto weightIndex = resolveWeightIndex(weightOwner, current);
if (!weightIndex)
return mlir::failure();
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
current = coreOp.getWeights()[*weightIndex];
continue;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
current = coreBatchOp.getWeights()[*weightIndex];
continue;
}
return mlir::failure();
}
}
} // namespace onnx_mlir
+64
View File
@@ -0,0 +1,64 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <optional>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedWeightView {
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t> shape;
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
bool operator==(const ResolvedWeightView& other) const {
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
}
};
bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable
/// weight across later PIM lowering/codegen stages.
void markWeightAlways(mlir::Operation* op);
bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use);
/// Returns true when a value flows only into Spatial weighted MVM/VMM operands,
/// allowing later passes to preserve it as a dedicated weight-like object.
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// Visits weight operands consumed by Pim core ops/core batches so downstream
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
template <typename CoreLikeOpTy>
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
llvm::SmallVector<unsigned, 8> indices;
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
indices.push_back(*weightIndex);
});
llvm::sort(indices);
return indices;
}
} // namespace onnx_mlir
-546
View File
@@ -1,546 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/Support/raw_os_ostream.h"
#include <filesystem>
#include <fstream>
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
void dumpModule(ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
os << *moduleOp;
os.flush();
file.close();
}
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
if (!moduleOp)
return failure();
SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
if (entryPoints.size() > 1) {
moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ") << entryPoints.size();
return failure();
}
if (!entryPoints.empty()) {
auto entryPointAttr =
entryPoints.front()->getAttrOfType<SymbolRefAttr>(ONNXEntryPointOp::getEntryPointFuncAttrName());
if (!entryPointAttr) {
entryPoints.front().emitOpError("is missing the entry point function attribute");
return failure();
}
auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(entryPointAttr.getLeafReference().getValue());
if (!entryFunc) {
entryPoints.front().emitOpError("references an unknown entry function ")
<< entryPointAttr.getLeafReference().getValue();
return failure();
}
return entryFunc;
}
if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
return mainGraphFunc;
SmallVector<func::FuncOp> nonExternalFuncs;
for (auto funcOp : moduleOp.getOps<func::FuncOp>())
if (!funcOp.isExternal())
nonExternalFuncs.push_back(funcOp);
if (nonExternalFuncs.size() == 1)
return nonExternalFuncs.front();
moduleOp.emitError("could not resolve a unique PIM entry function");
return failure();
}
bool hasWeightAlways(Operation* op) { return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr; }
void markWeightAlways(Operation* op) {
assert(op && "expected valid op");
op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
if (!moduleOp || !getGlobalOp)
return {};
return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}
FailureOr<Operation*> getOtherEndOfChannel(Operation* op, bool opIsReceive, RewriterBase& rewriter) {
auto channelNewOp = op->getOperand(0).getDefiningOp<spatial::SpatChannelNewOp>();
if (!channelNewOp) {
op->emitError("User of Channel must have the first operand created by ChannelNewOp.");
return failure();
}
// channelNewOp should have two users: `op` and a
// `ChannelSendOp`/`ChannelReceiveOp`
auto channelUsers = channelNewOp->getUsers();
auto usersIterator = channelUsers.begin();
auto firstUser = *usersIterator;
usersIterator++;
if (usersIterator == channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"only one found.");
channelNewOp->dump();
op->dump();
channelNewOp->getParentOp()->dump();
return failure();
}
auto secondUser = *usersIterator;
usersIterator++;
if (usersIterator != channelUsers.end()) {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"more than two found.");
return failure();
}
Operation* notOpUser;
if (firstUser == op) {
notOpUser = secondUser;
}
else if (secondUser == op) {
notOpUser = firstUser;
}
else {
op->emitError("Operand generated by ChannelNewOp must have two users, "
"and one of them must be me, but"
"none of them is actually me.");
return failure();
}
if (opIsReceive) {
if (!isa<spatial::SpatChannelSendOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelSendOp.");
return failure();
}
return notOpUser;
}
else {
if (!isa<spatial::SpatChannelReceiveOp>(notOpUser)) {
op->emitError("Operand generated by ChannelNewOp has two user, one is "
"me, the other is not a ChannelReceiveOp.");
return failure();
}
return notOpUser;
}
}
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides) {
SmallVector<int64_t> indices(shape.size(), 0);
for (auto [dim, stride] : llvm::enumerate(strides)) {
indices[dim] = linearIndex / stride;
linearIndex %= stride;
}
return indices;
}
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
int64_t linearIndex = 0;
for (auto [index, stride] : llvm::zip_equal(indices, strides))
linearIndex += index * stride;
return linearIndex;
}
int64_t getNumElements(ArrayRef<int64_t> shape) {
int64_t numElements = 1;
for (int64_t dim : shape)
numElements *= dim;
return numElements;
}
bool isMemoryContiguous(ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> offsets,
ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) -> bool { return stride != 1; }))
return false;
auto offsetsAndSizesAndShape = llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstNonZeroOffset = std::find_if(
offsetsAndSizesAndShape.begin(), offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
auto [offset, _size, _dimension] = offsetAndSizeAndShape;
return offset != 0;
});
if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
auto [offset, size, dimension] = *firstNonZeroOffset;
if (size > dimension - offset)
return false;
++firstNonZeroOffset;
if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(), [](auto offsetAndSizeAndShape) -> bool {
auto [_offset, size, _dimension] = offsetAndSizeAndShape;
return size != 1;
}))
return false;
}
auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
llvm::make_range(srcShape.rbegin(), srcShape.rend()));
auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(), [&](auto sizeAndShape) -> bool {
auto [size, dimension] = sizeAndShape;
return size != dimension;
});
if (firstDifferentSize != sizesAndShape.end()) {
++firstDifferentSize;
if (std::any_of(firstDifferentSize, sizesAndShape.end(), [](auto sizeAndShape) -> bool {
auto [size, _dimension] = sizeAndShape;
return size != 1;
}))
return false;
}
return true;
}
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
if (!knowledge)
return value;
auto iter = knowledge->aliases.find(value);
while (iter != knowledge->aliases.end()) {
value = iter->second;
iter = knowledge->aliases.find(value);
}
return value;
}
// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(value))
return value;
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return value;
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
if (auto result = dyn_cast<OpResult>(value))
if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
return value;
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge);
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
if (knowledge) {
auto iter = knowledge->indexValues.find(value);
if (iter != knowledge->indexValues.end())
return iter->second;
}
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (constantOp) {
if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
return integerAttr.getInt();
}
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs + *rhs;
}
if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs - *rhs;
}
if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return failure();
return *lhs * *rhs;
}
if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
return failure();
}
static FailureOr<int64_t> resolveOpFoldResult(OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
if (auto attr = dyn_cast<Attribute>(ofr)) {
auto integerAttr = dyn_cast<IntegerAttr>(attr);
if (!integerAttr)
return failure();
return integerAttr.getInt();
}
return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(Value value,
const StaticValueKnowledge* knowledge) {
int64_t byteOffset = 0;
value = resolveAlias(value, knowledge);
while (true) {
if (isa<BlockArgument>(value))
return ResolvedContiguousAddress {value, byteOffset};
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
if (!tiedOperand)
return failure();
value = resolveAlias(tiedOperand->get(), knowledge);
continue;
}
if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
auto result = dyn_cast<OpResult>(value);
if (!result)
return failure();
// Trace the loop carry back to its underlying memref, then if that memref is the
// loop's own iter-arg we know the base comes from the corresponding init arg
// (every iteration yields the same backing memory in the DPS sense).
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
Value yieldedValue = resolveLoopCarriedAliasImpl(yieldOp.getOperand(result.getResultNumber()), knowledge);
if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return failure();
SmallVector<int64_t> offsets;
SmallVector<int64_t> sizes;
SmallVector<int64_t> strides;
offsets.reserve(subviewOp.getMixedOffsets().size());
sizes.reserve(subviewOp.getMixedSizes().size());
strides.reserve(subviewOp.getMixedStrides().size());
for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
if (failed(resolvedOffset))
return failure();
offsets.push_back(*resolvedOffset);
}
for (OpFoldResult size : subviewOp.getMixedSizes()) {
auto resolvedSize = resolveOpFoldResult(size, knowledge);
if (failed(resolvedSize))
return failure();
sizes.push_back(*resolvedSize);
}
for (OpFoldResult stride : subviewOp.getMixedStrides()) {
auto resolvedStride = resolveOpFoldResult(stride, knowledge);
if (failed(resolvedStride))
return failure();
strides.push_back(*resolvedStride);
}
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
value = resolveAlias(castOp.getSource(), knowledge);
continue;
}
if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
value = resolveAlias(collapseOp.getSrc(), knowledge);
continue;
}
if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
value = resolveAlias(expandOp.getSrc(), knowledge);
continue;
}
if (isa<memref::AllocOp, memref::GetGlobalOp>(definingOp))
return ResolvedContiguousAddress {value, byteOffset};
return failure();
}
}
FailureOr<int64_t> resolveIndexValue(Value value) { return resolveIndexValueImpl(value, nullptr); }
FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value, const StaticValueKnowledge& knowledge) {
return resolveContiguousAddressImpl(value, &knowledge);
}
Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
bool isCoreStaticAddressOp(Operation* op) {
return isa<arith::ConstantOp,
arith::AddIOp,
arith::SubIOp,
arith::MulIOp,
arith::DivUIOp,
arith::RemUIOp,
arith::IndexCastOp,
memref::AllocOp,
memref::SubViewOp,
memref::CastOp,
memref::CollapseShapeOp,
memref::ExpandShapeOp>(op);
}
LogicalResult walkPimCoreBlock(Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (Operation& op : block) {
if (isa<pim::PimHaltOp, scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
hasFailure = true;
continue;
}
SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound; inductionValue += *step) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
hasFailure = true;
auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return success(!hasFailure);
}
} // namespace onnx_mlir
+11 -70
View File
@@ -7,82 +7,23 @@
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedContiguousAddress {
mlir::Value base;
int64_t byteOffset = 0;
};
struct StaticValueKnowledge {
llvm::DenseMap<mlir::Value, int64_t> indexValues;
llvm::DenseMap<mlir::Value, mlir::Value> aliases;
StaticValueKnowledge() {}
};
std::string getOutputDir();
void createDirectory(const std::string& directory);
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
llvm::FailureOr<mlir::func::FuncOp> getPimEntryFunc(mlir::ModuleOp moduleOp);
bool hasWeightAlways(mlir::Operation* op);
void markWeightAlways(mlir::Operation* op);
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
llvm::FailureOr<mlir::Operation*>
getOtherEndOfChannel(mlir::Operation* op, bool opIsReceive, mlir::RewriterBase& rewriter);
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t>
delinearizeIndex(int64_t linearIndex, llvm::ArrayRef<int64_t> shape, llvm::ArrayRef<int64_t> strides);
int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t> strides);
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Follows alias and view/DPS chains using `knowledge` to find the value an scf.for
/// iter-arg is ultimately backed by. Used when interpreting scf.for loop carries.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
/// Returns true for ops inside a pim.core body that do not emit any PIM instruction and
/// only contribute to static addressing or index computations (arith integer math,
/// memref view ops, memref.alloc, arith.constant).
bool isCoreStaticAddressOp(mlir::Operation* op);
/// Walks `block` (the body of a pim.core region or an scf.for nested in it), statically
/// unrolling any scf.for with resolvable bounds using `knowledge`. For each remaining op
/// that is not skipped (pim.halt, scf.yield, or isCoreStaticAddressOp), `callback` is
/// invoked with the op and the in-scope knowledge. The walker keeps going after a callback
/// failure so callers can collect multiple diagnostics, but propagates the overall result.
mlir::LogicalResult
walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
inline constexpr llvm::StringLiteral kCoreIdAttrName = "coreId";
inline constexpr llvm::StringLiteral kCoreIdsAttrName = "coreIds";
} // namespace onnx_mlir
@@ -0,0 +1,222 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir;
namespace onnx_mlir::pim {
namespace {
static void emitCrashMessage(llvm::StringRef fieldName, llvm::StringRef message) {
llvm::errs() << "PIM " << fieldName << " " << message << "\n";
}
template <typename To, typename From>
static FailureOr<To> checkedCastAtLocation(From value, Location loc, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCastAtLocation requires integral types");
using ToLimits = std::numeric_limits<To>;
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
return failure();
}
}
else if constexpr (std::is_signed_v<From>) {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::make_unsigned_t<To>;
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
return failure();
}
}
else {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(loc, fieldName, "is outside representable range");
return failure();
}
}
return static_cast<To>(value);
}
template <typename UInt>
FailureOr<UInt> checkedMulAtLocation(UInt lhs, UInt rhs, Location loc, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>,
"checkedMulAtLocation requires unsigned integral types");
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
emitCheckedArithmeticError(loc, fieldName, "multiplication overflow");
return failure();
}
return lhs * rhs;
}
} // namespace
InFlightDiagnostic emitCheckedArithmeticError(Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message) {
assert(anchor && "expected arithmetic diagnostics to have an anchor op");
return anchor->emitOpError() << fieldName << " " << message;
}
InFlightDiagnostic emitCheckedArithmeticError(Location loc, llvm::StringRef fieldName, llvm::StringRef message) {
return emitError(loc) << "PIM " << fieldName << " " << message;
}
FailureOr<int32_t> checkedI32(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<int32_t>(value, anchor, fieldName);
}
FailureOr<int32_t> checkedI32(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<int32_t>(value, anchor, fieldName);
}
FailureOr<uint8_t> checkedU8(uint64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<uint8_t>(value, anchor, fieldName);
}
FailureOr<size_t> checkedSize(int64_t value, Operation* anchor, llvm::StringRef fieldName) {
return checkedCast<size_t>(value, anchor, fieldName);
}
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, int64_t value, llvm::StringRef fieldName) {
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
auto checkedValue = checkedI32(value, anchor, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<IntegerAttr>
getCheckedI32Attr(Builder& builder, Operation* anchor, uint64_t value, llvm::StringRef fieldName) {
assert(anchor && "checked op-based attrs require a non-null diagnostic anchor");
auto checkedValue = checkedI32(value, anchor, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, int64_t value, llvm::StringRef fieldName) {
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<IntegerAttr> getCheckedI32Attr(Builder& builder, Location loc, uint64_t value, llvm::StringRef fieldName) {
auto checkedValue = checkedCastAtLocation<int32_t>(value, loc, fieldName);
if (failed(checkedValue))
return failure();
return builder.getI32IntegerAttr(*checkedValue);
}
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Operation* anchor, llvm::StringRef fieldName) {
assert(anchor && "checked op-based size helpers require a non-null diagnostic anchor");
if (!type.hasStaticShape()) {
emitCheckedArithmeticError(anchor, fieldName, "requires static shaped type");
return failure();
}
if (!hasByteSizedElementType(type.getElementType())) {
emitCheckedArithmeticError(anchor, fieldName, "requires byte-sized element type");
return failure();
}
uint64_t elements = 1;
for (int64_t dim : type.getShape()) {
if (dim < 0) {
emitCheckedArithmeticError(anchor, fieldName, "requires nonnegative dimensions");
return failure();
}
auto nextElements = checkedMul(elements, static_cast<uint64_t>(dim), anchor, fieldName);
if (failed(nextElements))
return failure();
elements = *nextElements;
}
return checkedMul(
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), anchor, fieldName);
}
FailureOr<uint64_t> getCheckedShapedTypeSizeInBytes(ShapedType type, Location loc, llvm::StringRef fieldName) {
if (!type.hasStaticShape()) {
emitCheckedArithmeticError(loc, fieldName, "requires static shaped type");
return failure();
}
if (!hasByteSizedElementType(type.getElementType())) {
emitCheckedArithmeticError(loc, fieldName, "requires byte-sized element type");
return failure();
}
uint64_t elements = 1;
for (int64_t dim : type.getShape()) {
if (dim < 0) {
emitCheckedArithmeticError(loc, fieldName, "requires nonnegative dimensions");
return failure();
}
auto nextElements = checkedMulAtLocation(elements, static_cast<uint64_t>(dim), loc, fieldName);
if (failed(nextElements))
return failure();
elements = *nextElements;
}
return checkedMulAtLocation(
elements, static_cast<uint64_t>(getElementTypeSizeInBytes(type.getElementType())), loc, fieldName);
}
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName) {
if (value < std::numeric_limits<int32_t>::min() || value > std::numeric_limits<int32_t>::max()) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<int32_t>(value);
}
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName) {
if (value > static_cast<uint64_t>(std::numeric_limits<int32_t>::max())) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<int32_t>(value);
}
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName) {
if (value > static_cast<uint64_t>(std::numeric_limits<uint8_t>::max())) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<uint8_t>(value);
}
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName) {
if (value < 0) {
emitCrashMessage(fieldName, "is outside representable range");
llvm_unreachable("PIM checked arithmetic failure");
}
return static_cast<size_t>(value);
}
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
if (rhs > std::numeric_limits<size_t>::max() - lhs) {
emitCrashMessage(fieldName, "addition overflow");
llvm_unreachable("PIM checked arithmetic failure");
}
return lhs + rhs;
}
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName) {
if (lhs != 0 && rhs > std::numeric_limits<size_t>::max() / lhs) {
emitCrashMessage(fieldName, "multiplication overflow");
llvm_unreachable("PIM checked arithmetic failure");
}
return lhs * rhs;
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,107 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/StringRef.h"
#include <cstddef>
#include <cstdint>
#include <limits>
#include <type_traits>
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Operation* anchor, llvm::StringRef fieldName, llvm::StringRef message);
mlir::InFlightDiagnostic
emitCheckedArithmeticError(mlir::Location loc, llvm::StringRef fieldName, llvm::StringRef message);
template <typename To, typename From>
mlir::FailureOr<To> checkedCast(From value, mlir::Operation* anchor, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<To> && std::is_integral_v<From>, "checkedCast requires integral types");
using ToLimits = std::numeric_limits<To>;
if constexpr (std::is_signed_v<From> == std::is_signed_v<To>) {
if (value < static_cast<From>(ToLimits::min()) || value > static_cast<From>(ToLimits::max())) {
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
return mlir::failure();
}
}
else if constexpr (std::is_signed_v<From>) {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::make_unsigned_t<To>;
if (value < 0 || static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
return mlir::failure();
}
}
else {
using UnsignedFrom = std::make_unsigned_t<From>;
using UnsignedTo = std::conditional_t<std::is_signed_v<To>, std::make_unsigned_t<To>, To>;
if (static_cast<UnsignedFrom>(value) > static_cast<UnsignedTo>(ToLimits::max())) {
emitCheckedArithmeticError(anchor, fieldName, "is outside representable range");
return mlir::failure();
}
}
return static_cast<To>(value);
}
template <typename UInt>
mlir::FailureOr<UInt> checkedAdd(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedAdd requires unsigned integral types");
if (rhs > std::numeric_limits<UInt>::max() - lhs) {
emitCheckedArithmeticError(anchor, fieldName, "addition overflow");
return mlir::failure();
}
return lhs + rhs;
}
template <typename UInt>
mlir::FailureOr<UInt> checkedMul(UInt lhs, UInt rhs, mlir::Operation* anchor, llvm::StringRef fieldName) {
static_assert(std::is_integral_v<UInt> && std::is_unsigned_v<UInt>, "checkedMul requires unsigned integral types");
if (lhs != 0 && rhs > std::numeric_limits<UInt>::max() / lhs) {
emitCheckedArithmeticError(anchor, fieldName, "multiplication overflow");
return mlir::failure();
}
return lhs * rhs;
}
mlir::FailureOr<int32_t> checkedI32(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<int32_t> checkedI32(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint8_t> checkedU8(uint64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<size_t> checkedSize(int64_t value, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Operation* anchor, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, int64_t value, llvm::StringRef fieldName);
mlir::FailureOr<mlir::IntegerAttr>
getCheckedI32Attr(mlir::Builder& builder, mlir::Location loc, uint64_t value, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Operation* anchor, llvm::StringRef fieldName);
mlir::FailureOr<uint64_t>
getCheckedShapedTypeSizeInBytes(mlir::ShapedType type, mlir::Location loc, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(int64_t value, llvm::StringRef fieldName);
int32_t checkedI32OrCrash(uint64_t value, llvm::StringRef fieldName);
uint8_t checkedU8OrCrash(uint64_t value, llvm::StringRef fieldName);
size_t checkedSizeOrCrash(int64_t value, llvm::StringRef fieldName);
size_t checkedAddOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
size_t checkedMulOrCrash(size_t lhs, size_t rhs, llvm::StringRef fieldName);
} // namespace onnx_mlir::pim
+27
View File
@@ -0,0 +1,27 @@
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include "src/Accelerators/PIM/Common/Support/DebugDump.hpp"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
namespace onnx_mlir {
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return;
std::string dialectsDir = outputDir + "/dialects";
createDirectory(dialectsDir);
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs().enableDebugInfo(true, false);
moduleOp.print(os, flags);
os.flush();
file.close();
}
} // namespace onnx_mlir
+13
View File
@@ -0,0 +1,13 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
#include <string>
namespace onnx_mlir {
/// Emits a MLIR snapshot under the current compiler output
/// directory for pass-level debugging.
void dumpModule(mlir::ModuleOp moduleOp, const std::string& name);
} // namespace onnx_mlir
+41
View File
@@ -0,0 +1,41 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
namespace onnx_mlir::pim {
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription) {
return op->emitOpError() << "requires statically shaped " << valueDescription;
}
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks) {
auto diag = op->emitOpError() << "has unsupported rank " << actualRank << " for " << valueDescription;
if (supportedRanks.empty())
return diag;
diag << "; supported rank";
if (supportedRanks.size() != 1)
diag << 's';
diag << ' ';
llvm::interleaveComma(supportedRanks, diag, [&](int64_t rank) { diag << rank; });
return diag;
}
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName) {
return op->emitOpError() << "references missing " << symbolKind << " `" << symbolName << "`";
}
mlir::LogicalResult emitFileSystemError(mlir::Location loc,
llvm::StringRef action,
llvm::StringRef path,
const std::error_code& errorCode) {
mlir::emitError(loc) << "failed to " << action << " `" << path << "`: " << errorCode.message();
return mlir::failure();
}
} // namespace onnx_mlir::pim
+64
View File
@@ -0,0 +1,64 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <system_error>
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
numFailures++;
if (numFailures <= maxReportedFailures)
emit(op);
}
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
void noteFailures(int64_t count) { numFailures += count; }
bool hasFailure() const { return numFailures != 0; }
private:
int64_t maxReportedFailures;
int64_t numFailures = 0;
};
/// Emits a consistent diagnostic for target paths that require static shapes.
mlir::InFlightDiagnostic emitUnsupportedStaticShapeDiagnostic(mlir::Operation* op, llvm::StringRef valueDescription);
/// Emits a consistent diagnostic for unsupported ranks while listing the ranks
/// accepted by the current lowering/codegen path.
mlir::InFlightDiagnostic emitUnsupportedRankDiagnostic(mlir::Operation* op,
llvm::StringRef valueDescription,
int64_t actualRank,
llvm::ArrayRef<int64_t> supportedRanks);
/// Emits a consistent diagnostic for missing symbol/global references.
mlir::InFlightDiagnostic
emitMissingSymbolDiagnostic(mlir::Operation* op, llvm::StringRef symbolKind, llvm::StringRef symbolName);
/// Converts a filesystem error into an MLIR failure diagnostic anchored at
/// the relevant IR location.
mlir::LogicalResult
emitFileSystemError(mlir::Location loc, llvm::StringRef action, llvm::StringRef path, const std::error_code& errorCode);
template <typename T>
mlir::LogicalResult failureOrToLogicalResult(const llvm::FailureOr<T>& value) {
return mlir::success(succeeded(value));
}
} // namespace onnx_mlir::pim
@@ -0,0 +1,24 @@
#include <filesystem>
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Compiler/CompilerOptions.hpp"
namespace onnx_mlir {
std::string getOutputDir() {
if (outputBaseName.empty() || outputBaseName == "-")
return {};
size_t lastSlash = outputBaseName.find_last_of('/');
if (lastSlash == std::string::npos)
return ".";
return outputBaseName.substr(0, lastSlash);
}
void createDirectory(const std::string& directory) {
std::error_code errorCode;
std::filesystem::create_directories(directory, errorCode);
assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}
} // namespace onnx_mlir
@@ -0,0 +1,13 @@
#pragma once
#include <string>
namespace onnx_mlir {
/// Returns the directory that should hold PIM artifacts/debug dumps for the
/// current compiler invocation.
std::string getOutputDir();
void createDirectory(const std::string& directory);
} // namespace onnx_mlir
+64
View File
@@ -0,0 +1,64 @@
#include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
namespace onnx_mlir {
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension) {
std::string outputDir = getOutputDir();
if (outputDir.empty())
return {};
std::string reportsDir = outputDir + "/reports";
createDirectory(reportsDir);
return std::fstream(reportsDir + "/" + name + "." + extension.str(), std::ios::out);
}
std::fstream openReportFile(const std::string& name) { return openReportFileWithExtension(name, "txt"); }
std::string formatReportMemory(uint64_t bytes) {
const char* units[] = {"B", "KB", "MB", "GB", "TB", "PB", "EB"};
int i = 0;
double size = static_cast<double>(bytes);
while (size >= 1024 && i < 6) {
size /= 1024;
i++;
}
std::string out;
llvm::raw_string_ostream rss(out);
rss << llvm::format("%.2f ", size) << units[i];
return rss.str();
}
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
for (const ReportField& field : fields)
os << "\t" << field.label << ": " << field.value << "\n";
}
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields) {
os << "\t" << title << ":\n";
for (const ReportField& field : fields)
os << "\t " << field.label << ": " << field.value << "\n";
}
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields) {
os << "Totals:\n";
for (const ReportField& field : fields)
os << "\t" << field.label << ": " << field.value << "\n";
}
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
llvm::ArrayRef<ReportField> perCoreFields,
llvm::ArrayRef<ReportField> totalFields) {
printReportFieldBlock(os, "Per core", perCoreFields);
printReportFieldBlock(os, "Total", totalFields);
}
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry) {
if (hasNextEntry)
os << "\n";
}
} // namespace onnx_mlir
+48
View File
@@ -0,0 +1,48 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <fstream>
#include <limits>
#include <string>
namespace onnx_mlir {
std::fstream openReportFile(const std::string& name);
std::fstream openReportFileWithExtension(const std::string& name, llvm::StringRef extension);
std::string formatReportMemory(uint64_t bytes);
struct ReportField {
std::string label;
std::string value;
};
void printReportFlatFields(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
void printReportFieldBlock(llvm::raw_ostream& os, llvm::StringRef title, llvm::ArrayRef<ReportField> fields);
void printReportTotalsBlock(llvm::raw_ostream& os, llvm::ArrayRef<ReportField> fields);
void printReportPerCoreAndTotalFields(llvm::raw_ostream& os,
llvm::ArrayRef<ReportField> perCoreFields,
llvm::ArrayRef<ReportField> totalFields);
void printReportEntrySeparator(llvm::raw_ostream& os, bool hasNextEntry);
template <typename EntryTy>
int32_t getFirstReportCoreId(const EntryTy& entry) {
if (entry.coreIds.empty())
return std::numeric_limits<int32_t>::max();
return entry.coreIds.front();
}
template <typename EntryRange>
void sortReportEntriesByFirstCore(EntryRange& entries) {
llvm::stable_sort(entries, [](const auto& lhs, const auto& rhs) {
int32_t lhsFirstCore = getFirstReportCoreId(lhs);
int32_t rhsFirstCore = getFirstReportCoreId(rhs);
if (lhsFirstCore != rhsFirstCore)
return lhsFirstCore < rhsFirstCore;
return lhs.id < rhs.id;
});
}
} // namespace onnx_mlir
+6
View File
@@ -15,7 +15,10 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimCodeGen.cpp
PimMemoryLiveness.cpp
PimWeightEmitter.cpp
EXCLUDE_FROM_OM_LIBS
@@ -26,6 +29,9 @@ add_pim_library(OMPimCompilerUtils
OMPimCompilerOptions
OMPimCommon
OMPimBufferization
OMPimMemoryCoalescing
OMPimHostConstantFolding
OMPimVerification
OMPimPasses
OMONNXToSpatial
OMSpatialToPim
+106
View File
@@ -0,0 +1,106 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <vector>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimArtifactWriter.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
OnnxMlirCompilerErrorCodes
writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory& memory, StringRef outputDirPath) {
auto memoryFilePath = (outputDirPath + "/memory.bin").str();
std::error_code errorCode;
raw_fd_ostream memoryFileStream(memoryFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening memory file " << memoryFilePath << ": " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
std::vector<char> memoryBuffer(memory.hostMem.getFirstAvailableAddress(), 0);
SmallPtrSet<Operation*, 16> writtenGlobals;
funcOp.walk([&](memref::GetGlobalOp getGlobalOp) {
if (hasWeightAlways(getGlobalOp))
return;
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp)
return;
if (!writtenGlobals.insert(globalOp.getOperation()).second)
return;
auto initialValue = globalOp.getInitialValue();
if (!initialValue)
return;
auto denseAttr = dyn_cast<DenseElementsAttr>(*initialValue);
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
if (denseAttr.isSplat()) {
size_t elementSize = rawData.size();
assert(elementSize * getGlobalOp.getType().getNumElements() == memEntry.size && "Data size mismatch");
for (size_t offset = 0; offset < memEntry.size; offset += elementSize)
std::memcpy(dst + offset, rawData.data(), std::min(elementSize, memEntry.size - offset));
}
else {
assert(rawData.size() == memEntry.size && "Data size mismatch");
std::memcpy(dst, rawData.data(), rawData.size());
}
});
memoryFileStream.write(memoryBuffer.data(), memoryBuffer.size());
memoryFileStream.close();
return CompilerSuccess;
}
OnnxMlirCompilerErrorCodes writeConfigJson(func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
json::Object xbarsPerArrayGroup,
StringRef outputDirPath) {
json::Object configJson;
configJson["core_cnt"] = maxCoreId + 1;
configJson["xbar_size"] = {crossbarSize.getValue(), crossbarSize.getValue()};
configJson["array_group_map"] = std::move(xbarsPerArrayGroup);
json::Array inputsAddresses;
for (BlockArgument input : funcOp.getArguments())
inputsAddresses.push_back(memory.getValueAddress(input));
configJson["inputs_addresses"] = std::move(inputsAddresses);
json::Array outputsAddresses;
for (func::ReturnOp returnOp : funcOp.getOps<func::ReturnOp>())
for (mlir::Value output : returnOp.getOperands())
outputsAddresses.push_back(memory.getValueAddress(output));
configJson["outputs_addresses"] = std::move(outputsAddresses);
auto configPath = (outputDirPath + "/config.json").str();
std::error_code errorCode;
raw_fd_ostream jsonOS(configPath, errorCode);
if (errorCode) {
errs() << "Error while opening config file: " << errorCode.message() << '\n';
return InvalidOutputFileAccess;
}
jsonOS << json::Value(std::move(configJson)) << '\n';
jsonOS.close();
return CompilerSuccess;
}
} // namespace onnx_mlir
+25
View File
@@ -0,0 +1,25 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/JSON.h"
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
namespace onnx_mlir {
class PimAcceleratorMemory;
OnnxMlirCompilerErrorCodes writeMemoryBinary(mlir::ModuleOp moduleOp,
mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
llvm::StringRef outputDirPath);
OnnxMlirCompilerErrorCodes writeConfigJson(mlir::func::FuncOp funcOp,
PimAcceleratorMemory& memory,
size_t maxCoreId,
llvm::json::Object xbarsPerArrayGroup,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
+369
View File
@@ -0,0 +1,369 @@
#pragma once
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
namespace onnx_mlir::pim_binary {
inline constexpr char kMagic[4] = {'P', 'I', 'M', 'B'};
inline constexpr uint32_t kVersion = 1;
inline constexpr uint64_t kCountOffset = 8;
inline constexpr size_t kHeaderSize = 12;
inline constexpr size_t kRecordSize = 20;
enum class Opcode : uint32_t {
nop = 0,
sldi = 1,
sld = 2,
sadd = 3,
ssub = 4,
smul = 5,
saddi = 6,
smuli = 7,
setbw = 8,
mvmul = 9,
vvadd = 10,
vvsub = 11,
vvmul = 12,
vvdmul = 13,
vvmax = 14,
vvsll = 15,
vvsra = 16,
vavg = 17,
vrelu = 18,
vtanh = 19,
vsigm = 20,
vsoftmax = 21,
vmv = 22,
vrsu = 23,
vrsl = 24,
ld = 25,
st = 26,
lldi = 27,
lmv = 28,
send = 29,
recv = 30,
wait = 31,
sync = 32,
};
struct InstructionRecord {
Opcode opcode = Opcode::nop;
uint8_t rd = 0;
uint8_t r1 = 0;
int32_t r2OrImm = 0;
int32_t generic1 = 0;
int32_t generic2 = 0;
int32_t generic3 = 0;
uint8_t flags = 0;
};
inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
std::array<char, sizeof(uint32_t)> bytes;
llvm::support::endian::write32le(bytes.data(), value);
os.write(bytes.data(), bytes.size());
}
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
inline void writeHeader(llvm::raw_ostream& os) {
os.write(kMagic, sizeof(kMagic));
writeUint32LE(os, kVersion);
writeUint32LE(os, 0);
}
inline void patchInstructionCount(llvm::raw_pwrite_stream& os, uint32_t instructionCount) {
std::array<char, sizeof(uint32_t)> bytes;
llvm::support::endian::write32le(bytes.data(), instructionCount);
os.pwrite(bytes.data(), bytes.size(), kCountOffset);
}
inline void writeInstructionRecord(llvm::raw_ostream& os, const InstructionRecord& record) {
os << static_cast<char>(static_cast<uint8_t>(record.opcode));
os << static_cast<char>(record.rd);
os << static_cast<char>(record.r1);
os << static_cast<char>(record.flags);
writeInt32LE(os, record.r2OrImm);
writeInt32LE(os, record.generic1);
writeInt32LE(os, record.generic2);
writeInt32LE(os, record.generic3);
}
inline int32_t toI32(int64_t value) { return onnx_mlir::pim::checkedI32OrCrash(value, "binary field"); }
inline uint8_t toU8(int64_t value) {
return onnx_mlir::pim::checkedU8OrCrash(static_cast<uint64_t>(value), "binary field");
}
inline int32_t getOptionalInt(const llvm::json::Object& object, llvm::StringRef key, int32_t defaultValue = 0) {
if (std::optional<int64_t> value = object.getInteger(key))
return toI32(*value);
return defaultValue;
}
inline Opcode opcodeFromString(llvm::StringRef opName) {
if (opName == "nop")
return Opcode::nop;
if (opName == "sldi")
return Opcode::sldi;
if (opName == "sld")
return Opcode::sld;
if (opName == "sadd")
return Opcode::sadd;
if (opName == "ssub")
return Opcode::ssub;
if (opName == "smul")
return Opcode::smul;
if (opName == "saddi")
return Opcode::saddi;
if (opName == "smuli")
return Opcode::smuli;
if (opName == "setbw")
return Opcode::setbw;
if (opName == "mvmul")
return Opcode::mvmul;
if (opName == "vvadd")
return Opcode::vvadd;
if (opName == "vvsub")
return Opcode::vvsub;
if (opName == "vvmul")
return Opcode::vvmul;
if (opName == "vvdmul")
return Opcode::vvdmul;
if (opName == "vvmax")
return Opcode::vvmax;
if (opName == "vvsll")
return Opcode::vvsll;
if (opName == "vvsra")
return Opcode::vvsra;
if (opName == "vavg")
return Opcode::vavg;
if (opName == "vrelu")
return Opcode::vrelu;
if (opName == "vtanh")
return Opcode::vtanh;
if (opName == "vsigm")
return Opcode::vsigm;
if (opName == "vsoftmax")
return Opcode::vsoftmax;
if (opName == "vmv")
return Opcode::vmv;
if (opName == "vrsu")
return Opcode::vrsu;
if (opName == "vrsl")
return Opcode::vrsl;
if (opName == "ld")
return Opcode::ld;
if (opName == "st")
return Opcode::st;
if (opName == "lldi")
return Opcode::lldi;
if (opName == "lmv")
return Opcode::lmv;
if (opName == "send")
return Opcode::send;
if (opName == "recv")
return Opcode::recv;
if (opName == "wait")
return Opcode::wait;
if (opName == "sync")
return Opcode::sync;
llvm_unreachable("Unsupported PIM binary opcode");
}
inline llvm::StringRef opcodeToString(Opcode opcode) {
switch (opcode) {
case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm";
case Opcode::vsoftmax: return "vsoftmax";
case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld";
case Opcode::st: return "st";
case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv";
case Opcode::send: return "send";
case Opcode::recv: return "recv";
case Opcode::wait: return "wait";
case Opcode::sync: return "sync";
}
llvm_unreachable("Unsupported PIM binary opcode");
}
inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruction) {
InstructionRecord record;
std::optional<llvm::StringRef> opName = instruction.getString("op");
assert(opName && "Missing op field in PIM instruction");
record.opcode = opcodeFromString(*opName);
record.rd = toU8(getOptionalInt(instruction, "rd"));
record.r1 = toU8(getOptionalInt(instruction, "rs1"));
switch (record.opcode) {
case Opcode::sldi:
case Opcode::saddi:
case Opcode::smuli:
case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
case Opcode::mvmul:
record.r2OrImm = getOptionalInt(instruction, "mbiw");
record.generic1 = getOptionalInt(instruction, "relu");
record.generic2 = getOptionalInt(instruction, "group");
break;
case Opcode::setbw:
record.generic1 = getOptionalInt(instruction, "ibiw");
record.generic2 = getOptionalInt(instruction, "obiw");
break;
case Opcode::send:
case Opcode::recv:
record.r2OrImm = getOptionalInt(instruction, "core");
record.generic3 = getOptionalInt(instruction, "size");
break;
default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
}
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
if (auto* offsetValue = instruction.getObject("offset")) {
record.generic1 = getOptionalInt(*offsetValue, "offset_select");
record.generic2 = getOptionalInt(*offsetValue, "offset_value");
}
}
if (instruction.get("len"))
record.generic3 = getOptionalInt(instruction, "len");
else if (instruction.get("size") && record.opcode != Opcode::send && record.opcode != Opcode::recv)
record.generic3 = getOptionalInt(instruction, "size");
return record;
}
inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
llvm::json::Object instruction;
instruction["op"] = opcodeToString(record.opcode).str();
auto addOffset = [&](int32_t offsetSelect, int32_t offsetValue) {
llvm::json::Object offset;
offset["offset_select"] = offsetSelect;
offset["offset_value"] = offsetValue;
instruction["offset"] = std::move(offset);
};
switch (record.opcode) {
case Opcode::sldi:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["imm"] = record.r2OrImm;
break;
case Opcode::sld:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
break;
case Opcode::sadd:
case Opcode::ssub:
case Opcode::smul:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["rs2"] = record.r2OrImm;
break;
case Opcode::saddi:
case Opcode::smuli:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["imm"] = record.r2OrImm;
break;
case Opcode::setbw:
instruction["ibiw"] = record.generic1;
instruction["obiw"] = record.generic2;
break;
case Opcode::mvmul:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["mbiw"] = record.r2OrImm;
instruction["relu"] = record.generic1;
instruction["group"] = record.generic2;
break;
case Opcode::vvadd:
case Opcode::vvsub:
case Opcode::vvmul:
case Opcode::vvdmul:
case Opcode::vvmax:
case Opcode::vvsll:
case Opcode::vvsra:
case Opcode::vavg:
case Opcode::vmv:
case Opcode::vrsu:
case Opcode::vrsl:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
instruction["rs2"] = record.r2OrImm;
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::vrelu:
case Opcode::vtanh:
case Opcode::vsigm:
case Opcode::vsoftmax:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::ld:
case Opcode::st:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
instruction["size"] = record.generic3;
break;
case Opcode::lldi:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["imm"] = record.r2OrImm;
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::lmv:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["rs1"] = static_cast<int64_t>(record.r1);
addOffset(record.generic1, record.generic2);
instruction["len"] = record.generic3;
break;
case Opcode::send:
case Opcode::recv:
instruction["rd"] = static_cast<int64_t>(record.rd);
instruction["core"] = record.r2OrImm;
addOffset(record.generic1, record.generic2);
instruction["size"] = record.generic3;
break;
case Opcode::wait:
case Opcode::sync:
case Opcode::nop: break;
}
return instruction;
}
} // namespace onnx_mlir::pim_binary
File diff suppressed because it is too large Load Diff
+160 -20
View File
@@ -1,10 +1,24 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <limits>
#include <optional>
#include <string>
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
@@ -14,57 +28,155 @@ struct MemEntry {
size_t size;
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
struct PhysicalSlotInfo {
size_t id = 0;
size_t address = 0;
size_t size = 0;
};
struct MemoryPlanArtifacts {
std::string textReport;
};
struct MemoryValueKey {
mlir::Value value;
std::optional<unsigned> lane;
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
};
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
uint64_t numGlobal = 0;
uint64_t sizeGlobal = 0;
bool operator==(const MemoryReportRow& other) const {
return numAlloca == other.numAlloca && sizeAlloca == other.sizeAlloca && numGlobal == other.numGlobal
&& sizeGlobal == other.sizeGlobal;
}
};
enum class MemoryReportKind {
None,
Alloca,
Global,
Input
};
struct PendingMemEntry {
MemEntry memEntry;
MemoryValueKey key;
MemoryReportKind reportKind = MemoryReportKind::None;
};
struct MemoryReportEntry {
enum class Kind {
Core,
Batch
};
Kind kind = Kind::Core;
uint64_t id = 0;
llvm::SmallVector<int32_t, 8> coreIds;
MemoryReportRow row;
uint64_t totalAllocaCount = 0;
uint64_t totalAllocaBytes = 0;
};
class PimMemory {
llvm::SmallVector<PendingMemEntry, 32> memEntries;
llvm::SmallVector<PhysicalSlotInfo, 32> localPhysicalSlots;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
MemoryReportRow reportRow;
MemoryPlanArtifacts livenessArtifacts;
size_t maxSize = 0; // 0 for unbounded memory
size_t startAddress = 0;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
size_t nextPhysicalSlotId = 0;
MemEntry* gatherMemEntry(mlir::Value value);
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry, MemoryReportKind reportKind);
PhysicalSlotInfo allocatePhysicalSlot(size_t slotSize, const MemoryValueKey& key);
public:
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
MemoryReportRow getReportRow() const;
const MemoryPlanArtifacts& getLivenessArtifacts() const { return livenessArtifacts; }
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const;
MemEntry getMemEntry(const MemoryValueKey& key) const;
};
class PimAcceleratorMemory {
public:
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
llvm::SmallDenseMap<size_t, PimMemory> deviceMem;
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
uint64_t totalWeightBytes = 0;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap) {}
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
: memEntriesMap(initialMemEntries),
hostMem(memEntriesMap),
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
size_t getValueAddress(mlir::Value value,
const StaticValueKnowledge& knowledge = {},
std::optional<unsigned> lane = std::nullopt) const;
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
llvm::ArrayRef<int32_t> coreIds,
const MemoryReportRow& perCoreRow,
uint64_t totalAllocaCount,
uint64_t totalAllocaBytes);
void setTotalWeightBytes(uint64_t bytes) { totalWeightBytes = bytes; }
void flushReport();
void clean(mlir::Operation* op);
};
struct CoreEmissionJob {
mlir::Operation* coreLikeOp = nullptr;
size_t originalCoreId = 0;
size_t emittedCoreId = 0;
llvm::SmallVector<unsigned, 4> lanes;
std::optional<uint64_t> batchReportId;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreFileStream;
llvm::raw_fd_ostream& coreBinaryStream;
llvm::raw_fd_ostream* coreJsonStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
return memory.getValueAddress(value, knowledge, batchLane);
}
size_t remapCoreId(size_t coreId) const;
static llvm::json::Object createEmptyOffset();
void emitInstruction(llvm::json::Object instruction) const;
void emitInstruction(const pim_binary::InstructionRecord& instruction) const;
void genSetRegisterImmediateUnsigned(size_t registerNumber, size_t immediate) const;
void setupRd(size_t rdAddress, size_t rdOffset) const;
@@ -83,8 +195,17 @@ class PimCodeGen {
void emitMvmOp(size_t groupId, size_t rdAddr, size_t rdOffset, size_t rs1Addr, size_t rs1Offset) const;
public:
PimCodeGen(PimAcceleratorMemory& memory, llvm::raw_fd_ostream& coreJson)
: memory(memory), coreFileStream(coreJson) {}
PimCodeGen(PimAcceleratorMemory& memory,
llvm::raw_fd_ostream& coreBinary,
llvm::raw_fd_ostream* coreJson,
const llvm::DenseMap<size_t, size_t>& emittedCoreIds)
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getIndexValue(value, knowledge);
}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
@@ -92,6 +213,7 @@ public:
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
void codeGenMVMLikeOp(size_t mvmId, MVMTy mvmLikeOp, bool transposeMatrix, const StaticValueKnowledge& knowledge);
@@ -106,9 +228,27 @@ public:
void codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowledge& knowledge) const;
void codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticValueKnowledge& knowledge) const;
void codeGetGlobalOp(mlir::memref::GetGlobalOp getGlobalOp, const StaticValueKnowledge& knowledge) const;
void codeGenTransposeOp(pim::PimTransposeOp transposeOp, const StaticValueKnowledge& knowledge) const;
};
OnnxMlirCompilerErrorCodes compileToPimJson(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
}
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
};
} // namespace llvm
+42 -22
View File
@@ -1,16 +1,5 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/
#include "llvm/Support/ErrorHandling.h"
//===------------------------- PimCompilerOptions.cpp --------------------===//
//
// Copyright 2022 The IBM Research Authors.
//
// =============================================================================
//
// Compiler Options for PIM
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
@@ -26,36 +15,67 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType>
pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport(
"pim-memory-report",
llvm::cl::desc("Emit a human-readable PIM memory planning report"),
llvm::cl::values(clEnumValN(PimMemoryReportNone, "none", "Do not emit any PIM memory planning report")),
llvm::cl::values(
clEnumValN(PimMemoryReportSummary, "summary", "Emit a concise slot reuse report with key offenders")),
llvm::cl::values(clEnumValN(PimMemoryReportFull, "full", "Emit the full detailed PIM memory planning report")),
llvm::cl::init(PimMemoryReportNone),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimOnlyCodegen("pim-only-codegen",
llvm::cl::desc("Only generate code for PIM (assume input is already in bufferized PIM IR)"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool>
pimDisableMemoryCoalescing("pim-disable-memory-coalescing",
llvm::cl::desc("Skip the PIM memory coalescing pass (developer diagnostic option)"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<bool> useExperimentalConvImpl("use-experimental-conv-impl",
llvm::cl::desc("Use experimental implementation for convolution"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t>
crossbarSize("crossbar-size", llvm::cl::desc("Width and heigth of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<bool> pimEmitJson("pim-emit-json",
llvm::cl::desc("Also emit per-core JSON instruction files alongside binary .pim files"),
llvm::cl::init(false),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(2));
crossbarSize("crossbar-size", llvm::cl::desc("Width and height of a single crossbar"), llvm::cl::init(2));
llvm::cl::opt<size_t>
crossbarCountInCore("crossbar-count", llvm::cl::desc("Number of crossbars in each core"), llvm::cl::init(256));
llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. `-1` to use the minimum amount of cores."),
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis."),
llvm::cl::init(1024));
llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
llvm::cl::init(false));
bool hasExplicitPimCoreCount() { return coresCount.getNumOccurrences() != 0; }
void verifyExplicitPimCoreCount() {
if (!hasExplicitPimCoreCount())
llvm::report_fatal_error("PIM compilation requires an explicit --core-count=<positive integer>");
if (coresCount.getValue() <= 0)
llvm::report_fatal_error("PIM compilation requires --core-count to be a positive integer");
}
} // namespace onnx_mlir
+17 -1
View File
@@ -20,16 +20,32 @@ typedef enum {
EmitPimCodegen = 3
} PimEmissionTargetType;
typedef enum {
MergeSchedulerPeft = 0,
} PimMergeSchedulerType;
typedef enum {
PimMemoryReportNone = 0,
PimMemoryReportSummary = 1,
PimMemoryReportFull = 2,
} PimMemoryReportLevel;
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget;
extern llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler;
extern llvm::cl::opt<PimMemoryReportLevel> pimMemoryReport;
extern llvm::cl::opt<bool> pimOnlyCodegen;
extern llvm::cl::opt<bool> pimDisableMemoryCoalescing;
extern llvm::cl::opt<bool> useExperimentalConvImpl;
extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
// This option, by default set to false, will ignore an error when resolving a
// specific tiles of the operands of a concat. This specific case is when the
+5 -7
View File
@@ -17,6 +17,7 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
PassManager& pm,
EmissionTargetType& emissionTarget,
std::string outputNameNoExt) {
verifyExplicitPimCoreCount();
if (pimOnlyCodegen) {
// Skip all the lowering passes and directly generate code for PIM.
@@ -29,31 +30,28 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass());
pm.addPass(createMergeComputeNodesPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
}
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
if (pimEmissionTarget >= EmitPimCodegen) {
pm.addPass(createPimHostConstantFoldingPass());
pm.addPass(createMessagePass("Pim host constants folded"));
pm.addPass(createPimMaterializeHostConstantsPass());
if (!pimDisableMemoryCoalescing)
pm.addPass(createPimMemoryCoalescingPass());
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimJsonPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim json code emitted"));
pm.addPass(createEmitPimCodePass());
pm.addPass(createMessagePass("Pim code emitted"));
}
}
+733
View File
@@ -0,0 +1,733 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <string>
#include <tuple>
#include <utility>
#include "Common/Support/CheckedArithmetic.hpp"
#include "Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimMemoryLiveness.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
using namespace onnx_mlir;
namespace {
static std::optional<unsigned> getLaneForMemoryValue(mlir::Value value, std::optional<unsigned> lane) {
if (!lane)
return std::nullopt;
auto allocOp = value.getDefiningOp<memref::AllocOp>();
if (!allocOp || !allocOp->getParentOfType<pim::PimCoreBatchOp>())
return std::nullopt;
return lane;
}
static MemoryValueKey getMemoryValueKey(mlir::Value value, std::optional<unsigned> lane = std::nullopt) {
return {value, getLaneForMemoryValue(value, lane)};
}
struct MemoryTouchInterval {
uint64_t start = 0;
uint64_t end = 0;
Operation* startOp = nullptr;
Operation* endOp = nullptr;
Operation* firstTouchOp = nullptr;
Operation* lastTouchOp = nullptr;
uint64_t firstTouchPosition = 0;
uint64_t lastTouchPosition = 0;
bool hasRuntimeUse = false;
bool startUsedAllocFallback = false;
bool endUsedFallback = false;
bool escapesLoop = false;
std::string fallbackReason;
llvm::SmallVector<std::string, 8> aliasesFollowed;
};
struct OperationOrdering {
llvm::DenseMap<Operation*, uint64_t> position;
llvm::DenseMap<Operation*, uint64_t> subtreeEnd;
uint64_t nextPosition = 0;
};
static std::string printValueToString(mlir::Value value) {
std::string text;
llvm::raw_string_ostream os(text);
value.print(os);
os.flush();
return text;
}
static std::string printOperationToString(Operation* op) {
if (!op)
return "<none>";
std::string text;
llvm::raw_string_ostream os(text);
op->print(os);
os.flush();
return text;
}
static std::string printLocationToString(Location loc) {
std::string text;
llvm::raw_string_ostream os(text);
loc.print(os);
os.flush();
return text;
}
static std::string collapseWhitespace(StringRef text) {
std::string out;
out.reserve(text.size());
bool lastWasSpace = false;
for (char c : text) {
bool isSpace = c == ' ' || c == '\n' || c == '\t' || c == '\r';
if (isSpace) {
if (!lastWasSpace && !out.empty())
out.push_back(' ');
lastWasSpace = true;
continue;
}
out.push_back(c);
lastWasSpace = false;
}
return out;
}
static std::string abbreviate(StringRef text, size_t maxLen) {
if (text.size() <= maxLen)
return text.str();
return (text.take_front(maxLen - 3) + "...").str();
}
static std::string summarizeValue(mlir::Value value, size_t maxLen = 72) {
return abbreviate(collapseWhitespace(printValueToString(value)), maxLen);
}
static std::string summarizeOperation(Operation* op, size_t maxLen = 96) {
if (!op)
return "<none>";
std::string prefix = op->getName().getStringRef().str();
std::string full = collapseWhitespace(printOperationToString(op));
if (full == prefix)
return prefix;
return abbreviate(prefix + " :: " + full, maxLen);
}
static std::string summarizeLocation(Location loc, size_t maxLen = 88) {
return abbreviate(collapseWhitespace(printLocationToString(loc)), maxLen);
}
static void assignOperationOrdering(Operation* op, OperationOrdering& ordering) {
uint64_t position = ordering.nextPosition++;
ordering.position[op] = position;
uint64_t end = position;
for (Region& region : op->getRegions())
for (Block& block : region)
for (Operation& nestedOp : block) {
assignOperationOrdering(&nestedOp, ordering);
end = std::max(end, ordering.subtreeEnd.lookup(&nestedOp));
}
ordering.subtreeEnd[op] = end;
}
static OperationOrdering buildOperationOrdering(Operation* coreLikeOp) {
OperationOrdering ordering;
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return ordering;
for (Operation& op : coreLikeOp->getRegion(0).front())
assignOperationOrdering(&op, ordering);
return ordering;
}
static bool isSupportedAliasOp(Operation* op) {
return isa<memref::SubViewOp, memref::CastOp, memref::CollapseShapeOp, memref::ExpandShapeOp>(op);
}
static bool isRuntimeMemoryTouchOp(Operation* op) {
return isa<pim::PimMemCopyHostToDevOp,
pim::PimMemCopyDevToHostOp,
pim::PimMemCopyOp,
pim::PimReceiveOp,
pim::PimSendOp,
pim::PimConcatOp,
pim::PimVMMOp,
pim::PimTransposeOp,
pim::PimVVAddOp,
pim::PimVVSubOp,
pim::PimVVMulOp,
pim::PimVVMaxOp,
pim::PimVVDMulOp,
pim::PimVAvgOp,
pim::PimVReluOp,
pim::PimVTanhOp,
pim::PimVSigmOp,
pim::PimVSoftmaxOp>(op);
}
static bool isIgnoredLivenessUser(Operation* op) {
return isSupportedAliasOp(op) || isa<scf::ForOp, scf::YieldOp, memref::DeallocOp>(op) || isCoreStaticAddressOp(op);
}
static bool isWithin(mlir::Value value, Region* region) {
if (!region)
return false;
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent() == region;
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion() == region || region->isAncestor(definingOp->getParentRegion());
return false;
}
static bool isNestedAllocation(Operation* coreLikeOp, memref::AllocOp allocOp) {
if (!coreLikeOp || coreLikeOp->getNumRegions() != 1 || coreLikeOp->getRegion(0).empty())
return false;
return allocOp->getBlock() != &coreLikeOp->getRegion(0).front();
}
static void addFallbackReason(std::string& reason, StringRef newReason) {
if (newReason.empty())
return;
if (!reason.empty())
reason += "; ";
reason += newReason.str();
}
static void appendAliasDescription(llvm::SmallVectorImpl<std::string>& aliases, mlir::Value value) {
std::string text = printValueToString(value);
if (!llvm::is_contained(aliases, text))
aliases.push_back(std::move(text));
}
struct OrderedTouchRange {
uint64_t start = 0;
uint64_t end = 0;
Operation* startOp = nullptr;
Operation* endOp = nullptr;
bool escapedLoop = false;
};
static OrderedTouchRange
getEffectiveTouchRange(mlir::Value definingValue, Operation* user, const OperationOrdering& ordering) {
OrderedTouchRange range {ordering.position.lookup(user), ordering.position.lookup(user), user, user, false};
for (Operation* current = user; current; current = current->getParentOp()) {
auto forOp = dyn_cast<scf::ForOp>(current);
if (!forOp || isWithin(definingValue, &forOp.getRegion()))
continue;
range.start = std::min(range.start, ordering.position.lookup(forOp));
range.end = std::max(range.end, ordering.subtreeEnd.lookup(forOp));
range.startOp = forOp;
range.endOp = forOp;
range.escapedLoop = true;
}
return range;
}
static MemoryTouchInterval
computeMemoryTouchInterval(memref::AllocOp allocOp, const OperationOrdering& ordering, uint64_t fallbackEnd) {
MemoryTouchInterval interval;
interval.start = ordering.position.lookup(allocOp);
interval.end = interval.start;
interval.startOp = allocOp;
interval.endOp = allocOp;
SmallPtrSet<mlir::Value, 16> visitedValues;
SmallPtrSet<Operation*, 32> visitedUsers;
SmallVector<mlir::Value> pendingValues;
pendingValues.push_back(allocOp.getResult());
auto parentLoop = allocOp->getParentOfType<scf::ForOp>();
while (!pendingValues.empty()) {
mlir::Value value = pendingValues.pop_back_val();
if (!visitedValues.insert(value).second)
continue;
for (Operation* user : value.getUsers()) {
if (!visitedUsers.insert(user).second)
continue;
if (isSupportedAliasOp(user)) {
for (mlir::Value result : user->getResults()) {
pendingValues.push_back(result);
appendAliasDescription(interval.aliasesFollowed, result);
}
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) {
OpOperand* tiedOperand = dpsOp.getTiedOpOperand(result);
if (!tiedOperand || tiedOperand->get() != value)
continue;
pendingValues.push_back(result);
appendAliasDescription(interval.aliasesFollowed, result);
}
}
if (auto forOp = dyn_cast<scf::ForOp>(user)) {
for (auto [index, initArg] : llvm::enumerate(forOp.getInitArgs())) {
if (initArg != value)
continue;
pendingValues.push_back(forOp.getRegionIterArgs()[index]);
pendingValues.push_back(forOp.getResult(index));
appendAliasDescription(interval.aliasesFollowed, forOp.getRegionIterArgs()[index]);
appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index));
if (parentLoop && forOp != parentLoop)
interval.escapesLoop = true;
}
}
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
if (!forOp) {
addFallbackReason(interval.fallbackReason, "yield without scf.for parent");
}
else {
for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) {
if (operand != value)
continue;
pendingValues.push_back(forOp.getResult(index));
appendAliasDescription(interval.aliasesFollowed, forOp.getResult(index));
if (parentLoop && forOp == parentLoop)
interval.escapesLoop = true;
}
}
}
if (isRuntimeMemoryTouchOp(user)) {
uint64_t touchPosition = ordering.position.lookup(user);
if (!interval.hasRuntimeUse || touchPosition < interval.firstTouchPosition) {
interval.firstTouchPosition = touchPosition;
interval.firstTouchOp = user;
}
if (!interval.hasRuntimeUse || touchPosition > interval.lastTouchPosition) {
interval.lastTouchPosition = touchPosition;
interval.lastTouchOp = user;
}
OrderedTouchRange range = getEffectiveTouchRange(allocOp.getResult(), user, ordering);
interval.escapesLoop |= range.escapedLoop;
if (!interval.hasRuntimeUse) {
interval.start = range.start;
interval.end = range.end;
interval.startOp = range.startOp;
interval.endOp = range.endOp;
interval.hasRuntimeUse = true;
}
else {
if (range.start < interval.start) {
interval.start = range.start;
interval.startOp = range.startOp;
}
if (range.end > interval.end) {
interval.end = range.end;
interval.endOp = range.endOp;
}
}
continue;
}
if (isIgnoredLivenessUser(user))
continue;
addFallbackReason(interval.fallbackReason, "unhandled user op");
interval.endUsedFallback = true;
}
}
if (!interval.hasRuntimeUse) {
interval.startUsedAllocFallback = true;
interval.endUsedFallback = true;
interval.start = ordering.position.lookup(allocOp);
interval.end = fallbackEnd;
interval.startOp = allocOp;
interval.endOp = allocOp->getParentOp();
interval.firstTouchPosition = interval.start;
interval.lastTouchPosition = interval.end;
addFallbackReason(interval.fallbackReason, "no runtime memory touch");
return interval;
}
if (interval.endUsedFallback) {
interval.end = std::max(interval.end, fallbackEnd);
interval.endOp = allocOp->getParentOp();
}
return interval;
}
static FailureOr<size_t> getAllocSizeBytes(memref::AllocOp allocOp) {
auto type = dyn_cast<ShapedType>(allocOp.getType());
if (!type)
return failure();
auto checkedBytes = pim::getCheckedShapedTypeSizeInBytes(type, allocOp, "memory allocation byte size");
if (failed(checkedBytes))
return failure();
return pim::checkedSize(*checkedBytes, allocOp, "memory allocation byte size");
}
static bool intervalsOverlap(const LocalAllocInterval& lhs, const LocalAllocInterval& rhs) {
return !(lhs.end < rhs.start || rhs.end < lhs.start);
}
static uint64_t getSlotLogicalBytes(const PlannedPhysicalSlot& slot, ArrayRef<LocalAllocInterval> intervals) {
uint64_t slotLogicalBytes = 0;
for (size_t intervalIndex : slot.intervalIndices)
slotLogicalBytes += intervals[intervalIndex].size;
return slotLogicalBytes;
}
} // namespace
SmallVector<LocalAllocInterval, 0> onnx_mlir::buildLocalAllocIntervals(Operation* coreLikeOp,
std::optional<unsigned> lane) {
SmallVector<LocalAllocInterval, 0> intervals;
OperationOrdering ordering = buildOperationOrdering(coreLikeOp);
if (ordering.position.empty())
return intervals;
uint64_t fallbackEnd = ordering.nextPosition == 0 ? 0 : ordering.nextPosition - 1;
size_t nextIntervalId = 0;
coreLikeOp->walk([&](memref::AllocOp allocOp) {
auto checkedSize = getAllocSizeBytes(allocOp);
if (failed(checkedSize)) {
llvm::errs() << "Failed to compute local allocation size for value: ";
allocOp.getResult().print(llvm::errs());
llvm::errs() << "\n";
llvm_unreachable("Failed to compute local allocation size");
}
MemoryTouchInterval touchInterval = computeMemoryTouchInterval(allocOp, ordering, fallbackEnd);
LocalAllocInterval interval;
interval.id = nextIntervalId++;
interval.alloc = allocOp;
interval.key = getMemoryValueKey(allocOp.getResult(), lane);
interval.start = touchInterval.start;
interval.end = touchInterval.end;
interval.size = *checkedSize;
interval.startOp = touchInterval.startOp;
interval.endOp = touchInterval.endOp;
interval.firstTouchOp = touchInterval.firstTouchOp;
interval.lastTouchOp = touchInterval.lastTouchOp;
interval.firstTouchPosition = touchInterval.firstTouchPosition;
interval.lastTouchPosition = touchInterval.lastTouchPosition;
interval.startUsedAllocFallback = touchInterval.startUsedAllocFallback;
interval.endUsedFallback = touchInterval.endUsedFallback;
interval.hasRuntimeUse = touchInterval.hasRuntimeUse;
interval.insideNestedRegion = isNestedAllocation(coreLikeOp, allocOp);
interval.escapesLoop = touchInterval.escapesLoop;
interval.fallbackReason = std::move(touchInterval.fallbackReason);
interval.aliasesFollowed = std::move(touchInterval.aliasesFollowed);
intervals.push_back(std::move(interval));
});
return intervals;
}
SmallVector<PlannedPhysicalSlot, 0> onnx_mlir::planPhysicalSlots(MutableArrayRef<LocalAllocInterval> intervals) {
SmallVector<PlannedPhysicalSlot, 0> slots;
SmallVector<size_t> intervalOrder(intervals.size());
std::iota(intervalOrder.begin(), intervalOrder.end(), 0);
llvm::stable_sort(intervalOrder, [&](size_t lhsIndex, size_t rhsIndex) {
const LocalAllocInterval& lhs = intervals[lhsIndex];
const LocalAllocInterval& rhs = intervals[rhsIndex];
if (lhs.size != rhs.size)
return lhs.size > rhs.size;
if (lhs.start != rhs.start)
return lhs.start < rhs.start;
if (lhs.end != rhs.end)
return lhs.end < rhs.end;
return lhs.id < rhs.id;
});
for (size_t intervalIndex : intervalOrder) {
LocalAllocInterval& interval = intervals[intervalIndex];
PlannedPhysicalSlot* bestSlot = nullptr;
auto bestKey = std::tuple<size_t, size_t, size_t, size_t>(std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max(),
std::numeric_limits<size_t>::max());
for (size_t slotIndex = 0; slotIndex < slots.size(); ++slotIndex) {
PlannedPhysicalSlot& slot = slots[slotIndex];
bool compatible = true;
for (size_t otherIndex : slot.intervalIndices) {
if (intervalsOverlap(interval, intervals[otherIndex])) {
compatible = false;
break;
}
}
if (!compatible)
continue;
size_t resultingSize = std::max(slot.requiredSize, interval.size);
size_t growth = resultingSize - slot.requiredSize;
auto candidateKey =
std::tuple<size_t, size_t, size_t, size_t>(growth, resultingSize, slot.intervalIndices.size(), slot.id);
if (candidateKey < bestKey) {
bestKey = candidateKey;
bestSlot = &slot;
}
}
if (!bestSlot) {
slots.push_back({slots.size(), interval.size, interval.size, 0, {intervalIndex}});
interval.slotPlanIndex = slots.size() - 1;
interval.physicalSlotId = slots.back().id;
interval.physicalSlotSize = slots.back().requiredSize;
continue;
}
bestSlot->requiredSize = std::max(bestSlot->requiredSize, interval.size);
bestSlot->size = bestSlot->requiredSize;
bestSlot->intervalIndices.push_back(intervalIndex);
interval.slotPlanIndex = static_cast<size_t>(bestSlot - slots.data());
interval.physicalSlotId = bestSlot->id;
interval.physicalSlotSize = bestSlot->requiredSize;
}
return slots;
}
MemoryPlanArtifacts onnx_mlir::buildMemoryPlanArtifacts(Operation* coreLikeOp,
std::optional<unsigned> lane,
ArrayRef<LocalAllocInterval> intervals,
ArrayRef<PlannedPhysicalSlot> slots,
size_t addressLimit,
PimMemoryReportLevel reportLevel) {
MemoryPlanArtifacts artifacts;
uint64_t totalLogicalBytes = 0;
uint64_t totalPhysicalBytes = 0;
uint64_t fallbackIntervals = 0;
uint64_t noRuntimeTouchIntervals = 0;
uint64_t reusedAllocations = 0;
uint64_t nestedIntervals = 0;
uint64_t loopEscapingIntervals = 0;
size_t largestLogicalAllocation = 0;
size_t largestPhysicalSlot = 0;
size_t maximumAssignedAddress = 0;
for (const LocalAllocInterval& interval : intervals) {
totalLogicalBytes += interval.size;
largestLogicalAllocation = std::max(largestLogicalAllocation, interval.size);
maximumAssignedAddress = std::max(maximumAssignedAddress, interval.assignedAddress + interval.physicalSlotSize);
if (interval.startUsedAllocFallback || interval.endUsedFallback)
++fallbackIntervals;
if (!interval.hasRuntimeUse)
++noRuntimeTouchIntervals;
if (interval.insideNestedRegion)
++nestedIntervals;
if (interval.escapesLoop)
++loopEscapingIntervals;
}
for (const PlannedPhysicalSlot& slot : slots) {
totalPhysicalBytes += slot.size;
largestPhysicalSlot = std::max(largestPhysicalSlot, slot.size);
if (slot.intervalIndices.size() > 1)
reusedAllocations += slot.intervalIndices.size() - 1;
}
uint64_t savedBytes = totalLogicalBytes >= totalPhysicalBytes ? totalLogicalBytes - totalPhysicalBytes : 0;
double savedPercent =
totalLogicalBytes == 0 ? 0.0 : 100.0 * static_cast<double>(savedBytes) / static_cast<double>(totalLogicalBytes);
raw_string_ostream os(artifacts.textReport);
os << "=== PIM Memory Liveness Report ===\n";
os << "Op: " << coreLikeOp->getName() << "\n";
if (lane)
os << "Lane: " << *lane << "\n";
os << "Summary:\n";
os << " logical allocation bytes: " << formatReportMemory(totalLogicalBytes) << " (" << totalLogicalBytes << ")\n";
os << " physical allocation bytes: " << formatReportMemory(totalPhysicalBytes) << " (" << totalPhysicalBytes
<< ")\n";
os << " saved bytes: " << formatReportMemory(savedBytes) << " (" << savedBytes << ")\n";
os << " saved percent: " << format("%.2f%%", savedPercent) << "\n";
os << " intervals: " << intervals.size() << "\n";
os << " physical slots: " << slots.size() << "\n";
os << " reused allocations: " << reusedAllocations << "\n";
os << " fallback intervals: " << fallbackIntervals << "\n";
os << " intervals with no runtime memory touch: " << noRuntimeTouchIntervals << "\n";
os << " nested allocations: " << nestedIntervals << "\n";
os << " loop-escaping allocations: " << loopEscapingIntervals << "\n";
os << " largest logical allocation: " << largestLogicalAllocation << "\n";
os << " largest physical slot: " << largestPhysicalSlot << "\n";
os << " address limit: " << addressLimit << "\n";
os << " peak physical memory: " << formatReportMemory(maximumAssignedAddress) << " (" << maximumAssignedAddress
<< ")\n";
os << " maximum assigned address: " << maximumAssignedAddress << "\n";
os << "\nHow To Read:\n";
os << " `summary` only shows the strongest reuse cases and the worst offenders.\n";
os << " Use `--pim-memory-report=full` when you need the complete slot-by-slot and interval-by-interval dump.\n";
os << " Large single-use slots, fallback intervals, and nested single-use allocations are the best places\n";
os << " to inspect if allocations should be moved, sunk, or made easier to coalesce earlier in the pipeline.\n";
SmallVector<const PlannedPhysicalSlot*> reusedSlots;
SmallVector<const PlannedPhysicalSlot*> singleUseSlots;
for (const PlannedPhysicalSlot& slot : slots)
if (slot.intervalIndices.size() > 1)
reusedSlots.push_back(&slot);
else
singleUseSlots.push_back(&slot);
llvm::stable_sort(reusedSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
uint64_t lhsLogicalBytes = getSlotLogicalBytes(*lhs, intervals);
uint64_t rhsLogicalBytes = getSlotLogicalBytes(*rhs, intervals);
if (lhs->intervalIndices.size() != rhs->intervalIndices.size())
return lhs->intervalIndices.size() > rhs->intervalIndices.size();
if (lhsLogicalBytes != rhsLogicalBytes)
return lhsLogicalBytes > rhsLogicalBytes;
if (lhs->size != rhs->size)
return lhs->size > rhs->size;
return lhs->id < rhs->id;
});
llvm::stable_sort(singleUseSlots, [&](const PlannedPhysicalSlot* lhs, const PlannedPhysicalSlot* rhs) {
if (lhs->size != rhs->size)
return lhs->size > rhs->size;
return lhs->id < rhs->id;
});
constexpr size_t kSummaryReuseLimit = 6;
constexpr size_t kSummaryOffenderLimit = 10;
os << "\nBest Reuse:\n";
if (reusedSlots.empty()) {
os << " no slots were shared by multiple intervals\n";
}
else {
for (const PlannedPhysicalSlot* slot : ArrayRef(reusedSlots).take_front(kSummaryReuseLimit)) {
uint64_t slotLogicalBytes = getSlotLogicalBytes(*slot, intervals);
os << " slot #" << slot->id << " addr=" << slot->address << " size=" << formatReportMemory(slot->size)
<< " intervals=" << slot->intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
<< "\n";
for (size_t intervalIndex : slot->intervalIndices) {
const LocalAllocInterval& interval = intervals[intervalIndex];
os << " #" << interval.id << " [" << interval.start << "," << interval.end << "]"
<< " logical=" << formatReportMemory(interval.size)
<< " first=" << summarizeOperation(interval.firstTouchOp, 40)
<< " last=" << summarizeOperation(interval.lastTouchOp, 40) << "\n";
}
}
}
os << "\nTop Offenders:\n";
bool printedAttention = false;
for (const PlannedPhysicalSlot* slot : ArrayRef(singleUseSlots).take_front(kSummaryOffenderLimit)) {
const LocalAllocInterval& interval = intervals[slot->intervalIndices.front()];
printedAttention = true;
os << " slot #" << slot->id << " is single-use"
<< " size=" << formatReportMemory(slot->size) << " interval=#" << interval.id
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " first=" << summarizeOperation(interval.firstTouchOp, 40)
<< " last=" << summarizeOperation(interval.lastTouchOp, 40)
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no") << "\n";
}
size_t fallbackPrinted = 0;
for (const LocalAllocInterval& interval : intervals) {
if (!(interval.startUsedAllocFallback || interval.endUsedFallback) || fallbackPrinted >= kSummaryOffenderLimit)
continue;
printedAttention = true;
++fallbackPrinted;
os << " fallback interval #" << interval.id << " size=" << formatReportMemory(interval.size)
<< " value=" << summarizeValue(interval.key.value, 56) << "\n";
os << " reason: " << (interval.fallbackReason.empty() ? "<none>" : interval.fallbackReason) << "\n";
}
size_t nestedPrinted = 0;
for (const LocalAllocInterval& interval : intervals) {
if (nestedPrinted >= kSummaryOffenderLimit)
break;
if (!(interval.insideNestedRegion && slots[interval.slotPlanIndex].intervalIndices.size() == 1))
continue;
printedAttention = true;
++nestedPrinted;
os << " nested single-use interval #" << interval.id << " slot #" << interval.physicalSlotId
<< " size=" << formatReportMemory(interval.size) << " value=" << summarizeValue(interval.key.value, 56)
<< "\n";
os << " hint: move or sink this alloc inside the nested region if the IR allows it.\n";
}
if (!printedAttention)
os << " no obvious blockers detected in this core\n";
if (reportLevel == PimMemoryReportFull) {
os << "\nSlot Reuse:\n";
for (const PlannedPhysicalSlot& slot : slots) {
uint64_t slotLogicalBytes = getSlotLogicalBytes(slot, intervals);
os << " slot #" << slot.id << " addr=" << slot.address << " size=" << formatReportMemory(slot.size) << " ("
<< slot.size << ")"
<< " intervals=" << slot.intervalIndices.size() << " logical_sum=" << formatReportMemory(slotLogicalBytes)
<< "\n";
for (size_t intervalIndex : slot.intervalIndices) {
const LocalAllocInterval& interval = intervals[intervalIndex];
mlir::Value allocValue = interval.key.value;
os << " [" << interval.start << "," << interval.end << "]"
<< " #" << interval.id << " logical=" << formatReportMemory(interval.size)
<< " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
<< " first=" << summarizeOperation(interval.firstTouchOp, 48)
<< " last=" << summarizeOperation(interval.lastTouchOp, 48) << "\n";
os << " value=" << summarizeValue(allocValue) << "\n";
}
}
}
if (reportLevel == PimMemoryReportFull) {
os << "\nInterval Details:\n";
for (const LocalAllocInterval& interval : intervals) {
const PlannedPhysicalSlot& slot = slots[interval.slotPlanIndex];
mlir::Value allocValue = interval.key.value;
Operation* definingOp = allocValue.getDefiningOp();
os << " #" << interval.id << " slot=" << slot.id << " live=[" << interval.start << "," << interval.end << "]"
<< " logical=" << formatReportMemory(interval.size)
<< " slot_size=" << formatReportMemory(interval.physicalSlotSize) << " addr=" << interval.assignedAddress
<< "\n";
os << " value=" << summarizeValue(allocValue, 88) << "\n";
os << " type=" << allocValue.getType() << "\n";
os << " loc="
<< summarizeLocation(definingOp ? definingOp->getLoc() : UnknownLoc::get(coreLikeOp->getContext())) << "\n";
os << " nested=" << (interval.insideNestedRegion ? "yes" : "no")
<< " escapes_loop=" << (interval.escapesLoop ? "yes" : "no")
<< " start_fallback=" << (interval.startUsedAllocFallback ? "yes" : "no")
<< " end_fallback=" << (interval.endUsedFallback ? "yes" : "no") << "\n";
os << " first_use=" << summarizeOperation(interval.firstTouchOp) << " @" << interval.firstTouchPosition
<< "\n";
os << " last_use=" << summarizeOperation(interval.lastTouchOp) << " @" << interval.lastTouchPosition << "\n";
os << " slot_peers=";
bool first = true;
for (size_t otherIndex : slot.intervalIndices) {
if (intervals[otherIndex].id == interval.id)
continue;
if (!first)
os << ", ";
os << "#" << intervals[otherIndex].id;
first = false;
}
if (first)
os << "<none>";
os << "\n";
if (!interval.fallbackReason.empty())
os << " fallback_reason=" << interval.fallbackReason << "\n";
if (!interval.aliasesFollowed.empty()) {
os << " aliases_followed=" << interval.aliasesFollowed.size() << "\n";
for (const std::string& alias : interval.aliasesFollowed)
os << " - " << abbreviate(collapseWhitespace(alias), 108) << "\n";
}
}
}
os.flush();
return artifacts;
}
+63
View File
@@ -0,0 +1,63 @@
#pragma once
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <limits>
#include <optional>
#include <string>
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
namespace onnx_mlir {
struct LocalAllocInterval {
size_t id = 0;
mlir::memref::AllocOp alloc;
MemoryValueKey key;
uint64_t start = 0;
uint64_t end = 0;
size_t size = 0;
mlir::Operation* startOp = nullptr;
mlir::Operation* endOp = nullptr;
mlir::Operation* firstTouchOp = nullptr;
mlir::Operation* lastTouchOp = nullptr;
uint64_t firstTouchPosition = 0;
uint64_t lastTouchPosition = 0;
bool startUsedAllocFallback = false;
bool endUsedFallback = false;
bool hasRuntimeUse = false;
bool insideNestedRegion = false;
bool escapesLoop = false;
std::string fallbackReason;
llvm::SmallVector<std::string, 8> aliasesFollowed;
size_t slotPlanIndex = std::numeric_limits<size_t>::max();
size_t physicalSlotId = std::numeric_limits<size_t>::max();
size_t assignedAddress = 0;
size_t physicalSlotSize = 0;
};
struct PlannedPhysicalSlot {
size_t id = std::numeric_limits<size_t>::max();
size_t requiredSize = 0;
size_t size = 0;
size_t address = 0;
llvm::SmallVector<size_t, 8> intervalIndices;
};
llvm::SmallVector<LocalAllocInterval, 0> buildLocalAllocIntervals(mlir::Operation* coreLikeOp,
std::optional<unsigned> lane);
llvm::SmallVector<PlannedPhysicalSlot, 0> planPhysicalSlots(llvm::MutableArrayRef<LocalAllocInterval> intervals);
MemoryPlanArtifacts buildMemoryPlanArtifacts(mlir::Operation* coreLikeOp,
std::optional<unsigned> lane,
llvm::ArrayRef<LocalAllocInterval> intervals,
llvm::ArrayRef<PlannedPhysicalSlot> slots,
size_t addressLimit,
PimMemoryReportLevel reportLevel);
} // namespace onnx_mlir
+93
View File
@@ -0,0 +1,93 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include "Common/Support/CheckedArithmetic.hpp"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {} // namespace
WeightEmissionResult createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
WeightEmissionResult result;
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
it != materializedWeights.end())
return it->second;
auto globalOp = weightView.globalOp;
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
assert(denseAttr && "Weight global must have dense initial value");
ArrayRef<int64_t> shape = weightView.shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
std::error_code errorCode;
raw_fd_ostream weightFileStream(weightFilePath, errorCode, sys::fs::OF_None);
if (errorCode) {
errs() << "Error while opening weight file `" << weightFilePath << "`: " << errorCode.message() << '\n';
assert(errorCode);
}
uint64_t zero = 0;
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
}
else {
weightFileStream.write(reinterpret_cast<const char*>(&zero), elementByteWidth);
}
}
}
weightFileStream.close();
materializedWeights.push_back({weightView, newFileName});
uint64_t weightBytes = pim::checkedMulOrCrash(
pim::checkedMulOrCrash(static_cast<size_t>(xbarSize), static_cast<size_t>(xbarSize), "weight element count"),
elementByteWidth,
"weight byte size");
result.totalWeightBytes = pim::checkedAddOrCrash(result.totalWeightBytes, weightBytes, "total weight bytes");
return newFileName;
};
for (const WeightFileRequest& request : requests) {
auto& coreFiles = result.mapCoreWeightToFileName[request.coreId];
coreFiles.reserve(request.weights.size());
for (const ResolvedWeightView& weight : request.weights)
coreFiles.push_back(materializeWeight(weight));
}
return result;
}
} // namespace onnx_mlir
+29
View File
@@ -0,0 +1,29 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <cstdint>
#include <string>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
namespace onnx_mlir {
struct WeightFileRequest {
size_t coreId = 0;
llvm::SmallVector<ResolvedWeightView, 8> weights;
};
struct WeightEmissionResult {
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
uint64_t totalWeightBytes = 0;
};
WeightEmissionResult createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests,
llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -3,6 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
Patterns.cpp
CompileTime.cpp
ONNXToSpatialVerifier.cpp
Patterns/Pre.cpp
Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
@@ -16,9 +22,15 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Gather.cpp
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Slice.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
Common.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
EXCLUDE_FROM_OM_LIBS
@@ -26,6 +38,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
-137
View File
@@ -1,137 +0,0 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Casting.h"
#include <cassert>
#include <optional>
#include <utility>
#include "Common.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (const auto size : shape)
sizes.push_back(rewriter.getIndexAttr(size));
sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis];
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
SmallVector<Value> slices;
slices.reserve(numSlices);
for (int64_t i = 0; i < numSlices; i++) {
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
if (i == numSlices - 1 && lastSliceSize != 0)
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
slices.push_back(slice);
}
return slices;
}
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1;
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
}
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
size_t coreId = sliceId / crossbarCountInCore;
slicesPerCore[coreId].push_back(slices[sliceId]);
}
return slicesPerCore;
}
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
size_t numHSlices = hSlices.size();
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
Value hSlice = hSlices[hSliceId];
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
size_t coreId = vSliceId / crossbarCountInCore;
Value vSlice = vSlices[vSliceId];
tiles[hSliceId][coreId].push_back(vSlice);
}
}
return tiles;
}
tensor::SplatOp
broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, scalarToBroadcast, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
}
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
}; // namespace onnx_mlir
-279
View File
@@ -1,279 +0,0 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <type_traits>
#include <utility>
#include "llvm/ADT/SmallPtrSet.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool isWeightLikeComputeOperand(mlir::Value value) {
auto rankedType = mlir::dyn_cast<mlir::RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<mlir::Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = mlir::dyn_cast<mlir::tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = mlir::dyn_cast<mlir::tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
continue;
}
return false;
}
return false;
}
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
} // namespace detail
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, mlir::LogicalResult>) {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
else {
static_assert(std::is_same_v<BodyResult, void>, "createSpatCompute body must return void or mlir::LogicalResult");
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
}
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::tensor::SplatOp broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
}; // namespace onnx_mlir
@@ -0,0 +1,23 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
return attr ? getI64Attr(*attr, index) : defaultValue;
}
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
llvm::SmallVector<int64_t> values;
values.reserve(attr.size());
for (Attribute value : attr)
values.push_back(cast<IntegerAttr>(value).getInt());
return values;
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
namespace onnx_mlir {
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
} // namespace onnx_mlir
@@ -0,0 +1,9 @@
#pragma once
#include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -0,0 +1,39 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "ComputeRegionBuilder.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
Value sumTensors(ArrayRef<Value> tensors, ConversionPatternRewriter& rewriter) {
if (tensors.size() == 1)
return tensors[0];
SmallVector<Value> tensors1 = {tensors.begin(), tensors.end()};
SmallVector<Value> tensors2;
tensors2.reserve(tensors.size() / 2);
auto* currTensors = &tensors1;
auto* nextTensors = &tensors2;
while (currTensors->size() > 1) {
for (size_t i = 0; i < currTensors->size() - 1; i += 2) {
Value a = (*currTensors)[i];
Value b = (*currTensors)[i + 1];
rewriter.setInsertionPointAfterValue(b);
auto addedValue = spatial::SpatVAddOp::create(rewriter, a.getLoc(), a.getType(), a, b);
nextTensors->push_back(addedValue);
}
if (currTensors->size() % 2 == 1)
nextTensors->push_back(currTensors->back());
std::swap(currTensors, nextTensors);
nextTensors->clear();
}
assert(currTensors->size() == 1 && "Expected a single input at this point.");
return (*currTensors)[0];
}
} // namespace onnx_mlir
@@ -0,0 +1,267 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cassert>
#include <cstddef>
#include <limits>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Common/Support/CheckedArithmetic.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
template <size_t>
using ValueArg = mlir::Value;
template <typename Fn, typename Seq>
struct InvokeWithBlockArgsResult;
template <typename Fn, size_t... Is>
struct InvokeWithBlockArgsResult<Fn, std::index_sequence<Is...>> {
using type = std::invoke_result_t<Fn, ValueArg<Is>...>;
};
template <typename Fn, typename Seq>
using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::type;
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
struct SpatComputeBatchBodyArgs {
mlir::Value lane;
mlir::ValueRange weights;
mlir::ValueRange inputs;
mlir::ValueRange outputs;
};
} // namespace detail
template <typename RewriterT>
inline mlir::Value createSpatConcat(RewriterT& rewriter, mlir::Location loc, int64_t axis, mlir::ValueRange inputs) {
assert(!inputs.empty() && "spat.concat requires at least one input");
if (inputs.size() == 1)
return inputs.front();
auto firstType = mlir::cast<mlir::RankedTensorType>(inputs.front().getType());
auto outputShape = llvm::to_vector(firstType.getShape());
int64_t concatDimSize = 0;
bool concatDimDynamic = false;
for (mlir::Value input : inputs) {
auto inputType = mlir::cast<mlir::RankedTensorType>(input.getType());
assert(inputType.getRank() == firstType.getRank() && "spat.concat expects same-rank inputs");
if (mlir::ShapedType::isDynamic(inputType.getDimSize(axis)))
concatDimDynamic = true;
else
concatDimSize += inputType.getDimSize(axis);
}
outputShape[axis] = concatDimDynamic ? mlir::ShapedType::kDynamic : concatDimSize;
auto outputType = mlir::RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
return spatial::SpatConcatOp::create(rewriter, loc, outputType, rewriter.getI64IntegerAttr(axis), inputs).getOutput();
}
/// Builds a `spat.compute` with a fixed number of SSA inputs and erases it if
/// the body callback reports failure.
template <size_t NumInputs, typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
assert(inputs.size() == NumInputs && "NumInputs must match the number of input values");
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
/// Builds a `spat.compute` whose body consumes the block arguments as a single
/// `ValueRange`, which is convenient for variadic reductions/concats.
template <typename RewriterT, typename BodyFn>
auto createSpatCompute(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
computeOp.getBody().push_back(block);
rewriter.setInsertionPointToStart(block);
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(mlir::failure());
}
rewriter.setInsertionPointAfter(computeOp);
return mlir::FailureOr<spatial::SpatCompute>(computeOp);
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto laneCountAttr = pim::getCheckedI32Attr(rewriter, loc, laneCount, "spatial compute_batch lane count");
if (mlir::failed(laneCountAttr))
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(rewriter, loc, resultTypes, *laneCountAttr, weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
for (mlir::Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(weight.getLoc());
}
for (mlir::Value input : inputs) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(input.getLoc());
}
for (mlir::Type resultType : resultTypes) {
blockArgTypes.push_back(resultType);
blockArgLocs.push_back(loc);
}
auto* block =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(block);
detail::SpatComputeBatchBodyArgs args {
block->getArgument(0),
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
mlir::ArrayRef<mlir::OpFoldResult> offsets,
mlir::ArrayRef<mlir::OpFoldResult> sizes,
mlir::ArrayRef<mlir::OpFoldResult> strides) {
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
}
template <typename BodyFn>
mlir::Value materializeOrComputeUnary(mlir::Value input,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc,
BodyFn&& build) {
auto&& buildFn = build;
if (isCompileTimeComputable(input))
return buildFn(input);
auto computeOp = createSpatCompute<1>(
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
mlir::Value result = buildFn(computeInput);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#include <algorithm>
#include "IndexingUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
int64_t normalizedAxis = normalizeAxis(axis, rank);
if (normalizedAxis < 0 || normalizedAxis >= rank)
return failure();
return normalizedAxis;
}
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; ++axis)
normalizedAxes.push_back(axis);
}
else {
normalizedAxes.reserve(axesAttr->size());
for (Attribute attr : *axesAttr)
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
}
return normalizedAxes;
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes)
if (axis < 0 || axis >= rank)
return failure();
return normalizedAxes;
}
} // namespace onnx_mlir
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank);
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
} // namespace onnx_mlir
@@ -0,0 +1,179 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
using namespace mlir;
namespace onnx_mlir {
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis];
auto [numSlices, lastSliceSize] = ceilIntegerDivideWithRemainder(length, sliceSize);
SmallVector<Value> slices;
slices.reserve(numSlices);
for (int64_t i = 0; i < numSlices; i++) {
offsets[axis] = rewriter.getIndexAttr(i * sliceSize);
int64_t currentSliceSize = sliceSize;
if (i == numSlices - 1 && lastSliceSize != 0) {
currentSliceSize = lastSliceSize;
sizes[axis] = rewriter.getIndexAttr(lastSliceSize);
}
SmallVector<int64_t> sliceShape(shape.begin(), shape.end());
sliceShape[axis] = currentSliceSize;
auto sliceType =
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isCompileTimeComputable(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
auto sliceCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {sliceType}, {}, ValueRange {tensorToSlice}, [&](Value input) {
Value computedSlice = tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
spatial::SpatYieldOp::create(rewriter, loc, computedSlice);
});
slice = sliceCompute.getResult(0);
}
slices.push_back(slice);
}
return slices;
}
SmallVector<Value>
sliceVector(const Value& vectorToSlice, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(vectorToSlice);
assert("Not a vector" && isVectorShape(shape));
size_t axis = shape[0] != 1 ? 0 : 1;
return sliceTensor(vectorToSlice, axis, sliceSize, rewriter, loc);
}
DenseMap<CoreId, SmallVector<Value>>
sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewriter& rewriter, Location loc) {
SmallVector<Value> slices = sliceVector(vectorToSlice, crossbarSize, rewriter, loc);
DenseMap<CoreId, SmallVector<Value>> slicesPerCore;
for (size_t sliceId = 0; sliceId < slices.size(); sliceId++) {
size_t coreId = sliceId / crossbarCountInCore;
slicesPerCore[coreId].push_back(slices[sliceId]);
}
return slicesPerCore;
}
Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<int64_t> resultShape(sourceType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
return tensor::InsertSliceOp::create(rewriter,
loc,
source,
dest,
offsets,
getStaticSizes(rewriter, sourceType.getShape()),
getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
} // namespace onnx_mlir
@@ -0,0 +1,114 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
namespace onnx_mlir {
using HSliceId = size_t;
using CoreId = size_t;
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr C ceilIntegerDivide(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return 1 + (ac - 1) / bc;
}
template <class A, class B, class C = std::common_type_t<A, B>>
constexpr std::pair<C, C> ceilIntegerDivideWithRemainder(A a, B b) {
static_assert(std::is_integral_v<A>, "A must be an integer type");
static_assert(std::is_integral_v<B>, "B must be an integer type");
C ac = static_cast<C>(a);
C bc = static_cast<C>(b);
return {ceilIntegerDivide(ac, bc), ac % bc};
}
template <class T>
bool isVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && (shape[0] == 1 || shape[1] == 1);
}
template <class T>
bool isMatrixShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2;
}
template <class T>
bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
auto lhsType = mlir::dyn_cast<mlir::RankedTensorType>(lhs.getType());
auto rhsType = mlir::dyn_cast<mlir::RankedTensorType>(rhs.getType());
return lhsType && rhsType && lhsType.hasStaticShape() && rhsType.hasStaticShape()
&& lhsType.getShape() == rhsType.getShape();
}
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
size_t axis,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
int64_t sliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
/// Partitions one logical vector into per-core crossbar-sized slices using the
/// current PIM target geometry.
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
} // namespace onnx_mlir
@@ -0,0 +1,115 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool isWeightLikeComputeOperand(Value value) {
auto rankedType = dyn_cast<RankedTensorType>(value.getType());
if (!rankedType || !isMatrixShape(rankedType.getShape()))
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
while (auto* definingOp = value.getDefiningOp()) {
if (!visited.insert(definingOp).second)
return false;
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp) || hasWeightAlways(definingOp))
return true;
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
value = extractSliceOp.getSource();
continue;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
value = expandShapeOp.getSrc();
continue;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
value = transposeOp.getInput();
continue;
}
return false;
}
return false;
}
FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewriter, IRMapping& mapper) {
if (auto mapped = mapper.lookupOrNull(value))
return cast<Value>(mapped);
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return failure();
if (isa<arith::ConstantOp, ONNXConstantOp>(definingOp)) {
auto tensorType = dyn_cast<RankedTensorType>(value.getType());
if (!tensorType || !tensorType.hasStaticShape())
return failure();
SmallVector<OpFoldResult> offsets(tensorType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(tensorType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(tensorType.getRank());
for (int64_t dim : tensorType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
auto referencedValue =
tensor::ExtractSliceOp::create(rewriter, value.getLoc(), tensorType, value, offsets, sizes, strides);
mapper.map(value, referencedValue.getResult());
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
return failure();
IRMapping localMapper;
for (Value operand : definingOp->getOperands()) {
if (auto mapped = mapper.lookupOrNull(operand)) {
localMapper.map(operand, cast<Value>(mapped));
continue;
}
if (isWeightLikeComputeOperand(operand)) {
auto clonedOperand = materializeWeightLikeValueInBlock(operand, rewriter, mapper);
if (failed(clonedOperand))
return failure();
localMapper.map(operand, *clonedOperand);
continue;
}
localMapper.map(operand, operand);
}
Operation* clonedOp = rewriter.clone(*definingOp, localMapper);
for (auto [oldResult, newResult] : llvm::zip(definingOp->getResults(), clonedOp->getResults()))
mapper.map(oldResult, newResult);
auto mapped = mapper.lookupOrNull(value);
if (!mapped)
return failure();
return cast<Value>(mapped);
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
/// Returns true when a matrix-valued compute operand is ultimately backed by a
/// weight-marked constant/view chain and can be promoted into weights.
bool isWeightLikeComputeOperand(mlir::Value value);
/// Rebuilds the view/transpose chain of a promoted weight operand inside a new
/// compute body while reusing already-materialized intermediate values.
llvm::FailureOr<mlir::Value>
materializeWeightLikeValueInBlock(mlir::Value value, mlir::IRRewriter& rewriter, mlir::IRMapping& mapper);
} // namespace onnx_mlir
@@ -0,0 +1,299 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
return llvm::all_of(extractSliceOp.getStaticStrides(), [](int64_t stride) { return stride == 1; });
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
}
static bool isStaticTensorResult(Operation* op) {
return llvm::all_of(op->getResultTypes(), [](Type type) {
auto shapedType = dyn_cast<ShapedType>(type);
return shapedType && shapedType.hasStaticShape();
});
}
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
return failure();
int64_t rank = tensorType.getRank();
if (static_cast<int64_t>(perms.size()) != rank)
return failure();
llvm::SmallBitVector seen(rank);
SmallVector<int64_t> transposedShape;
transposedShape.reserve(rank);
for (int64_t perm : perms) {
if (perm < 0 || perm >= rank || seen.test(perm))
return failure();
seen.set(perm);
transposedShape.push_back(tensorType.getShape()[perm]);
}
auto transposedType = RankedTensorType::get(transposedShape, tensorType.getElementType(), tensorType.getEncoding());
if (denseAttr.isSplat())
return DenseElementsAttr::get(transposedType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> originalValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> transposedValues(originalValues.size());
SmallVector<int64_t> originalStrides = computeRowMajorStrides(tensorType.getShape());
SmallVector<int64_t> transposedStrides = computeRowMajorStrides(transposedShape);
SmallVector<int64_t> originalIndices(rank);
for (auto [linearIndex, value] : llvm::enumerate(originalValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < rank; ++dim) {
originalIndices[dim] = remaining / originalStrides[dim];
remaining %= originalStrides[dim];
}
int64_t transposedLinearIndex = 0;
for (int64_t dim = 0; dim < rank; ++dim)
transposedLinearIndex += originalIndices[perms[dim]] * transposedStrides[dim];
transposedValues[transposedLinearIndex] = value;
}
return DenseElementsAttr::get(transposedType, transposedValues);
}
static FailureOr<DenseElementsAttr> reshapeDenseElements(DenseElementsAttr denseAttr, RankedTensorType resultType) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!sourceType || !resultType || sourceType.getNumElements() != resultType.getNumElements())
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> values(denseAttr.getValues<Attribute>());
return DenseElementsAttr::get(resultType, values);
}
static FailureOr<DenseElementsAttr> extractSliceDenseElements(DenseElementsAttr denseAttr,
tensor::ExtractSliceOp extractSliceOp) {
auto sourceType = dyn_cast<RankedTensorType>(denseAttr.getType());
auto resultType = dyn_cast<RankedTensorType>(extractSliceOp.getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
ArrayRef<int64_t> offsets = extractSliceOp.getStaticOffsets();
ArrayRef<int64_t> sizes = extractSliceOp.getStaticSizes();
ArrayRef<int64_t> strides = extractSliceOp.getStaticStrides();
if (llvm::any_of(offsets, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(sizes, [](int64_t value) { return ShapedType::isDynamic(value); })
|| llvm::any_of(strides, [](int64_t stride) { return ShapedType::isDynamic(stride) || stride != 1; }))
return failure();
if (denseAttr.isSplat())
return DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>());
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<int64_t> sourceStrides = computeRowMajorStrides(sourceType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<Attribute> resultValues;
resultValues.reserve(resultType.getNumElements());
for (int64_t linearIndex = 0; linearIndex < resultType.getNumElements(); ++linearIndex) {
int64_t remaining = linearIndex;
int64_t sourceLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim) {
const int64_t resultIndex = resultStrides.empty() ? 0 : remaining / resultStrides[dim];
remaining = resultStrides.empty() ? 0 : remaining % resultStrides[dim];
sourceLinearIndex += (offsets[dim] + resultIndex) * sourceStrides[dim];
}
resultValues.push_back(sourceValues[sourceLinearIndex]);
}
return DenseElementsAttr::get(resultType, resultValues);
}
static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
if (auto constantOp = value.getDefiningOp<arith::ConstantOp>())
return dyn_cast<DenseElementsAttr>(constantOp.getValue());
if (auto constantOp = value.getDefiningOp<ONNXConstantOp>())
return dyn_cast_or_null<DenseElementsAttr>(constantOp.getValueAttr());
return nullptr;
}
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
// Rebuild dense attributes through view-only host-foldable chains so later
// lowering stages can still recognize grouped/sliced constants.
if (auto denseAttr = getDirectDenseConstantAttr(value))
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm;
perm.reserve(transposeOp.getPermAttr().size());
for (IntegerAttr attr : transposeOp.getPermAttr().getAsRange<IntegerAttr>())
perm.push_back(attr.getInt());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
return succeeded(reshapedAttr) ? *reshapedAttr : nullptr;
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
return succeeded(slicedAttr) ? *slicedAttr : nullptr;
}
return nullptr;
}
static std::optional<CompileTimeSource>
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
if (!op)
return std::nullopt;
if (!visited.insert(op).second)
return {
{op, chainLength}
};
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return {
{op, chainLength}
};
chainLength += 1;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp)
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (!isStaticTensorResult(op))
return std::nullopt;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp)
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
std::optional<CompileTimeSource> res = {};
for (auto operandValue : concatOp.getOperands()) {
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
if (!partialRes)
return std::nullopt;
if (!res) {
res = partialRes;
continue;
}
if (res->chainLength < partialRes->chainLength)
res = partialRes;
}
return res;
}
return std::nullopt;
}
} // namespace
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited);
}
bool isCompileTimeComputable(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(definingOp, visited).has_value();
}
bool isCompileTimeOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited).has_value();
}
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostConstantDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
@@ -0,0 +1,22 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
struct CompileTimeSource {
mlir::Operation* source;
size_t chainLength;
};
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
bool isCompileTimeComputable(mlir::Value value);
bool isCompileTimeOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir

Some files were not shown because too many files have changed in this diff Show More