56 Commits

Author SHA1 Message Date
NiccoloN b678e55d3c compact memory contiguity with for loops
Validate Operations / validate-operations (push) Waiting to run
2026-05-31 18:47:59 +02:00
NiccoloN ab63498f3f normalize affine arithmetic helpers
Validate Operations / validate-operations (push) Has been cancelled
2026-05-30 16:37:28 +02:00
NiccoloN 7c3943bd06 Merge remote-tracking branch 'origin/refactorone' into refactorone
Validate Operations / validate-operations (push) Has been cancelled
# Conflicts:
#	src/PIM/Dialect/Pim/Transforms/Bufferization/PimBufferizationPass.cpp
2026-05-30 16:12:42 +02:00
NiccoloN c0238c0d06 fix high memory usage caused by MaterializeMergeSchedule.cpp with more robust code 2026-05-30 16:12:06 +02:00
NiccoloN ff36729140 centralize logic for materializing contiguous memory into bufferization
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 16:09:58 +02:00
NiccoloN cf93caecd5 centralize logic for materializing contiguous memory into bufferization
Validate Operations / validate-operations (push) Has been cancelled
fix codegen symlinks overwrite
remove deprecated pim memcp_hd_batch op
2026-05-30 15:54:24 +02:00
NiccoloN 2d5b03c08f automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 19:21:37 +02:00
NiccoloN a41f694cf0 batched matmul pattern
Validate Operations / validate-operations (push) Has been cancelled
add conv helpers
new validation tests for matmul
2026-05-29 19:09:48 +02:00
NiccoloN 8bb0babf1b finish helper refactoring
Validate Operations / validate-operations (push) Has been cancelled
use uniqued constant helpers everywhere
materialize transposed constants directly
2026-05-29 17:05:45 +02:00
ilgeco 819d8af0f7 Refactor + ReduceMean batched
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 15:57:13 +02:00
ilgeco 832bd7f1f7 Transpose and Refactor of Patterns
Validate Operations / validate-operations (push) Has been cancelled
2026-05-29 13:23:31 +02:00
ilgeco 82b44a6387 New Onnx test gemm model 2026-05-29 11:41:30 +02:00
ilgeco 7fcc765d6e New Onnx Test model 2026-05-29 11:37:17 +02:00
ilgeco f34698a2b6 Validate new option for compile only
Validate Operations / validate-operations (push) Has been cancelled
2026-05-28 22:59:26 +02:00
ilgeco 1ab489fe0a Dynamic gemm/conv 2026-05-28 18:00:14 +02:00
ilgeco cbf7b235f1 pim-simulator now support usize addresses
Validate Operations / validate-operations (push) Has been cancelled
2026-05-28 17:03:19 +02:00
NiccoloN 00414dd1d9 add verification of communication invariants at the end of spatial
Validate Operations / validate-operations (push) Has been cancelled
remove dead logic
2026-05-27 19:17:48 +02:00
NiccoloN 783dffe553 fix scheduling cost model
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 17:14:19 +02:00
NiccoloN 874a2f53e6 automatic code reformat
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:39:56 +02:00
NiccoloN 4bdaa57656 simplify affine maps to constants where possible
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:39:27 +02:00
NiccoloN 1a5d7d2a3f fix bufferization and weight emission after new gemm patterns
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 16:15:10 +02:00
ilgeco 013ae0ac2a Update README and AGENTS
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 15:09:30 +02:00
ilgeco c6b02af7a9 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-27 14:32:51 +02:00
ilgeco d2048bd394 Add to gitignore 2026-05-27 14:32:47 +02:00
NiccoloN 158f0f0c54 update AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 14:32:04 +02:00
NiccoloN 532cac8246 commit AGENTS.md
Validate Operations / validate-operations (push) Has been cancelled
2026-05-27 14:07:34 +02:00
NiccoloN d609e84054 teh only weight (WIP)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-26 18:42:14 +02:00
NiccoloN addfc8a86e remove other dead logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 21:22:08 +02:00
NiccoloN 0f240af271 cleanup unused channel operations and related logic
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 20:58:51 +02:00
ilgeco bdc4ca33f3 No extract no more
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 18:19:43 +02:00
ilgeco b79c333c6c Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone 2026-05-25 15:44:40 +02:00
ilgeco eea9261c7b Bye Bye DCP 2026-05-25 15:44:30 +02:00
NiccoloN e8a08f6dd0 faster pim VerificationPass.cpp and pim code emission
Validate Operations / validate-operations (push) Has been cancelled
2026-05-25 15:24:12 +02:00
NiccoloN 4855a2e105 add verification of static weights in spatial
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 12:00:42 +02:00
NiccoloN 3a7a832198 MaterializeMergeSchedule.cpp fix for yolo11_depth_18 2026-05-24 11:54:00 +02:00
NiccoloN 48ca6bd28d speed fix with a simple cache
Validate Operations / validate-operations (push) Has been cancelled
2026-05-24 10:52:28 +02:00
NiccoloN f595cc6ffd fix high memory usage in IR 2026-05-24 10:41:47 +02:00
NiccoloN c734f1b37e better MaterializeMergeSchedule.cpp that emits much more compact IR
Validate Operations / validate-operations (push) Has been cancelled
add support for other constant-time arith ops in codegen
2026-05-24 10:10:24 +02:00
NiccoloN b79ce8eeaa use affine dialect to express simple constant progressions
Validate Operations / validate-operations (push) Has been cancelled
run dce at the end of MaterializeMergeSchedule to get rid of unused constants
2026-05-23 14:25:34 +02:00
NiccoloN 76a37e198f better MaterializeMergeSchedule.cpp with both send and receive compaction in for loops
Validate Operations / validate-operations (push) Has been cancelled
2026-05-23 11:17:36 +02:00
NiccoloN 7f3c7464b4 update cost model of batch lanes to consider only a slice of the shared batch input
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 22:16:19 +02:00
NiccoloN c77ffa9c56 better MaterializeMergeSchedule.cpp with %lane indexed batch computes
support for tensors of index values
2026-05-22 21:52:28 +02:00
NiccoloN 495186503c fix cmake magic once again 2026-05-22 19:21:56 +02:00
NiccoloN 2c1da813b5 fix much stuff 2026-05-22 18:53:38 +02:00
NiccoloN 8337a11ce9 automatic code reformat 2026-05-22 15:23:48 +02:00
ilgeco d136136d22 Fix add of input in random order for compute_batch
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 15:21:02 +02:00
NiccoloN 074eb183c7 saner SpatialToPimPass architecture
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 07:27:54 +02:00
NiccoloN 43ed3914b8 better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled
2026-05-22 06:56:39 +02:00
ilgeco 6aaf1c0870 Merge branch 'refactorone' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor into refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:44:19 +02:00
ilgeco fe35b3ed43 Equivalent Class but broken 2026-05-21 14:43:59 +02:00
NiccoloN 90a9339686 better cmake to keep IDEs analyses happy
Validate Operations / validate-operations (push) Has been cancelled
2026-05-21 14:13:54 +02:00
NiccoloN a50e77ff38 refactorone
Validate Operations / validate-operations (push) Has been cancelled
2026-05-20 19:06:41 +02:00
NiccoloN f56c4159b5 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor 2026-05-19 15:01:26 +02:00
ilgeco 5637c861b4 Merge branch 'main' of chef.heaplab.deib.polimi.it:nnicolosi/Raptor
Validate Operations / validate-operations (push) Has been cancelled
2026-05-19 15:00:11 +02:00
ilgeco 94157a8404 Very big timeout 2026-05-19 14:53:34 +02:00
ilgeco 68a3521978 Perft topological fix 2026-05-19 14:52:54 +02:00
187 changed files with 13264 additions and 11441 deletions
+2 -5
@@ -4,14 +4,11 @@
.claude
.codex
AGENTS.md
CMakeUserPresets.json
build
build_release
cmake-build-debug
cmake-build-release
build_*
compile.sh
pimcomp_utils/*
**/__*
+92
@@ -0,0 +1,92 @@
- Always read the full README.md before doing anything.
- Build commands:
- `cmake --build ./build_release`
- `cmake --build ./build_debug`
- Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache.
- Always try the release build first, and ask before building the debug version.
# Code changes
- Keep changes minimal and localized to the relevant parts of the code.
- Preserve the existing naming conventions and coding style used in the surrounding code.
- Keep code easy to read, well organized, and suitable for future extensibility. A function must not be longer than
200/250 lines for readability and cognitive complexity.
- Prefer clear naming and structure over comments. Add comments only when they materially improve clarity.
- Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change.
# Working style
- Infer style and conventions from the existing code before introducing new patterns.
- When several implementation options are possible, prefer the simplest one that fits the current architecture and
minimizes churn.
- Avoid broad refactors unless I explicitly ask for them.
# Responses
- When showing code in chat, make it easy to copy-paste into the codebase.
- Keep outputs focused on the changed parts.
- At the end of the response, briefly list any bad practices, mistakes, or cleaner alternatives you noticed, separate
from the main solution.
# Guidelines
## 1. Think Before Coding
**Don't assume. Don't hide confusion. Surface tradeoffs.**
Before implementing:
- State your assumptions explicitly. If uncertain, ask.
- If multiple interpretations exist, present them - don't pick silently.
- If a simpler approach exists, say so. Push back when warranted.
- If something is unclear, stop. Name what's confusing. Ask.
## 2. Simplicity First
**Minimum code that solves the problem. Nothing speculative.**
- No features beyond what was asked.
- No error handling for impossible scenarios.
- If you write 200 lines and it could be 50, rewrite it.
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
## 3. Surgical Changes
**Touch only what you must. Clean up only your own mess.**
When editing existing code:
- Don't "improve" adjacent code, comments, or formatting.
- Don't refactor things that aren't broken.
- Match existing style, even if you'd do it differently.
- If you notice unrelated dead code, mention it - don't delete it.
When your changes create orphans:
- Remove imports/variables/functions that YOUR changes made unused.
- Don't remove pre-existing dead code unless asked, but mention it.
The test: Every changed line should trace directly to the user's request.
## 4. Goal-Driven Execution
**Define success criteria. Loop until verified.**
Transform tasks into verifiable goals:
- "Add validation" → "Write tests for invalid inputs, then make them pass"
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
- "Refactor X" → "Ensure tests pass before and after"
For multi-step tasks, state a brief plan:
```
1. [Step] → verify: [check]
2. [Step] → verify: [check]
3. [Step] → verify: [check]
```
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
---
+85 -17
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
project(raptor)
# Add symlink to PIM as accelerator in onnx-mlir
function(raptor_ensure_symlink link_path target_path)
get_filename_component(link_parent "${link_path}" DIRECTORY)
# Materialize a CMake shim directory
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
if(NOT EXISTS "${link_parent}")
message(FATAL_ERROR "Directory not found: ${link_parent}")
endif()
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
message(FATAL_ERROR
"External CMake source directory not found or missing CMakeLists.txt:\n"
" ${real_external_source_dir}"
)
endif ()
if (IS_SYMLINK "${shim_dir}")
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
file(REMOVE "${shim_dir}")
endif ()
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
endif ()
file(MAKE_DIRECTORY "${shim_dir}")
set(shim_file "${shim_dir}/CMakeLists.txt")
set(shim_contents
"get_filename_component(raptor_external_source_dir
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
REALPATH
)
add_subdirectory(
\"\${raptor_external_source_dir}\"
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
)
if (DEFINED PIM_ENABLED)
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
endif ()
"
)
if (EXISTS "${shim_file}")
file(READ "${shim_file}" old_contents)
else ()
set(old_contents "")
endif ()
if (NOT old_contents STREQUAL shim_contents)
file(WRITE "${shim_file}" "${shim_contents}")
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
else ()
message(STATUS "CMake shim already up to date for ${description}")
endif ()
# Mirror the external tree's first-level entries into the shim directory
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
foreach (child IN LISTS children)
if (child STREQUAL "CMakeLists.txt")
continue()
endif ()
set(real_child "${real_external_source_dir}/${child}")
set(shim_child "${shim_dir}/${child}")
if (IS_SYMLINK "${shim_child}")
file(READ_SYMLINK "${shim_child}" existing_link_target)
if (existing_link_target STREQUAL real_child)
continue()
endif ()
file(REMOVE_RECURSE "${shim_child}")
elseif (EXISTS "${shim_child}")
# Do not delete real files/directories. This protects the generated shim.
continue()
endif ()
if(NOT EXISTS "${link_path}")
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
file(CREATE_LINK
"${target_path}"
"${link_path}"
"${real_child}"
"${shim_child}"
SYMBOLIC
)
endif()
endforeach ()
endfunction()
raptor_ensure_symlink(
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
"PIM accelerator"
)
raptor_ensure_symlink(
raptor_write_external_cmake_shim(
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
"PIM accelerator tests"
)
# Patch onnx-mlir sources for PIM accelerator support.
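The mirroring loop in `raptor_write_external_cmake_shim` is easier to follow outside CMake syntax. The following is only a Python sketch of the same idea, not part of the repository: `mirror_first_level` is a hypothetical name, and it assumes a POSIX filesystem where symlinks can be created. It skips `CMakeLists.txt`, keeps links that already point at the right target, replaces stale links, and never touches real files or directories in the shim.

```python
import os
from pathlib import Path


def mirror_first_level(real_dir: Path, shim_dir: Path) -> list[str]:
    """Symlink first-level entries of real_dir into shim_dir (hypothetical helper).

    Returns the names of the entries that were (re)linked by this call.
    """
    real_dir = real_dir.resolve()
    shim_dir.mkdir(parents=True, exist_ok=True)
    linked = []
    for child in sorted(real_dir.iterdir()):
        if child.name == "CMakeLists.txt":
            continue  # the shim directory owns its own CMakeLists.txt
        shim_child = shim_dir / child.name
        if shim_child.is_symlink():
            if os.readlink(shim_child) == str(child):
                continue  # link already points at the right target
            shim_child.unlink()  # stale link: remove and recreate below
        elif shim_child.exists():
            continue  # real file or directory: never delete it
        shim_child.symlink_to(child)
        linked.append(child.name)
    return linked
```

Calling the function twice on the same pair of directories should link nothing the second time, which is the property that keeps repeated CMake configure runs cheap.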
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
# Already applied replacement text is present
string(FIND "${contents}" "${replacement}" already_applied_pos)
if(NOT already_applied_pos EQUAL -1)
if (NOT already_applied_pos EQUAL -1)
message(STATUS "Patch already applied: ${description}")
return()
endif()
endif ()
# Anchor must exist for the patch to be applicable
string(FIND "${contents}" "${anchor}" anchor_pos)
if(anchor_pos EQUAL -1)
if (anchor_pos EQUAL -1)
message(FATAL_ERROR
"Patch anchor not found onnx-mlir may have changed.\n"
" Patch : ${description}\n"
" File : ${file_path}\n"
" Anchor: ${anchor}"
)
endif()
endif ()
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
file(WRITE "${file_path}" "${patched}")
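The check order in `raptor_apply_patch` (already applied, then anchor present, then replace) is what makes the patch safe to run on every configure. As a minimal sketch of that logic in Python, with `apply_patch` being a hypothetical name and the file I/O left out:

```python
def apply_patch(contents: str, anchor: str, replacement: str) -> tuple[str, bool]:
    """Return (new_contents, changed) following the same checks as the CMake helper."""
    # 1. Replacement text already present: nothing to do.
    if replacement in contents:
        return contents, False
    # 2. Anchor missing: the upstream file changed and the patch cannot apply.
    if anchor not in contents:
        raise ValueError("patch anchor not found; upstream source may have changed")
    # 3. Replace every occurrence of the anchor, as string(REPLACE) does.
    return contents.replace(anchor, replacement), True
```

Because a typical replacement contains the anchor plus the added text, a second run takes the first branch and leaves the file unchanged.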
+250 -155
@@ -1,226 +1,321 @@
# Raptor
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
targeting in-memory computing / processing-in-memory (PIM) architectures.
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
Raptor is a domain-specific MLIR compiler for neural networks in ONNX format,
targeting in-memory computing / processing-in-memory (PIM) architectures. It
extends ONNX-MLIR with a PIM accelerator and progressively lowers ONNX-MLIR
through custom MLIR dialects to simulator artifacts.
The current target is the PIM simulator stack under `backend-simulators/pim`.
Raptor emits binary per-core `.pim` instruction files by default, plus
`memory.bin`, `config.json`, and weight binaries. It can also emit per-core JSON
instruction files with `--pim-emit-json`.
## Overview
PIM architectures perform most of the computation directly in memory.
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
- a shared host memory,
- a number of cores that do most of the computation directly in their memory
(vector ops, vmm/mvm on ReRAM crossbars),
- no branching instructions (branchless architecture) and no hardware loop
support — any repeated work (e.g. convolutions) must be unrolled into
explicit per-iteration instructions.
PIM architectures perform most computation directly in memory. The supported
target models a chip with:
- shared host memory,
- multiple PIM cores,
- ReRAM crossbars for vector-matrix / matrix-vector work,
- explicit communication between cores,
- no hardware branch or loop support in emitted simulator code.
Because of this, the amount of emitted instructions explodes quickly and the
compiler must optimize aggressively at every stage to keep compilation
tractable.
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
each carrying its own in-memory computing unit and crossbars. It will live in
a dedicated dialect (future work).
Because repeated work such as convolutions is eventually made explicit, emitted
instruction counts can grow quickly. Most compiler work therefore focuses on
lowering, scheduling, memory layout, and code-generation optimizations.
### Targets and simulators
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
**performance** estimates (latency, energy), but does not functionally execute
the JSON code it consumes. To validate the numerical correctness of the JSON
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
a Rust simulator we maintain in-tree at
`backend-simulators/pim/pim-simulator`.
- `backend-simulators/pim/pim-simulator` is the in-tree Rust functional
simulator used by validation. It reads Raptor's `pim/` artifact directory and
compares simulator output against native ONNX-MLIR execution.
- `backend-simulators/pim/pimsim-nn` is the performance simulator submodule.
The helper scripts in `pimcomp_utils/` are for comparison with PIMCOMP-NN and
contain local paths; treat them as local utilities, not portable workflows.
## Compilation pipeline
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
When working on this codebase, most changes should stay confined to those
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
framework-level details).
The PIM sources live under `src/PIM` and tests under `test/PIM`. CMake exposes
them to ONNX-MLIR through generated shim directories under
`onnx-mlir/src/Accelerators/PIM` and `onnx-mlir/test/accelerators/PIM`.
High-level lowering flow:
```
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
ONNX-MLIR -> Spatial -> Pim (tensor) -> Pim (bufferized) -> PIM artifacts
```
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
operations are accelerated by storing a constant RHS matrix into a
crossbar. Crossbars cannot be re-programmed during execution, have a
limited fixed size, and there is a limited number of them per core.
Conversion patterns are split by op family under
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
Reshape, Resize, Split).
1. **ONNX -> Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
Lowers supported ONNX ops into the `spat` dialect
(`src/PIM/Dialect/Spatial`). Conversion patterns are split by op family under
`Patterns/{Math,NN,Tensor}` and currently cover Conv, Gemm, MatMul,
elementwise Add/Mul/Div, ReduceMean, pooling, Relu, Sigmoid, Softmax,
Concat, Gather, Reshape, Resize, and Split.
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
materializes PIM cores (`pim.core`), inter-core communication
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
2. **Merge compute nodes**
(`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
Builds a compute graph, schedules it with the PEFT scheduler, and materializes
the merge schedule into Spatial IR. Supporting scheduling code lives under
`MergeComputeNodes/Scheduling`.
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
A DCP-inspired heuristic (Dynamic Critical Path — see the original
scheduling paper by Kwok & Ahmad,
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
that coarsens the virtual node graph and decides how to group compute
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
heuristic with different assumptions from the paper (different cost
model, constraints from crossbar capacity / core resources, and a
windowed coarsening loop instead of full-graph reprioritization). The
`dcp-critical-window-size` option controls how many lowest-slack virtual
nodes each coarsening iteration considers (0 = legacy full-graph
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
`MergeComputeNodesPass.cpp`.
3. **Spatial -> Pim** (`src/PIM/Conversion/SpatialToPim`).
Lowers Spatial operations to the `pim` dialect (`src/PIM/Dialect/Pim`),
including `pim.core`, `pim.core_batch`, communication, tensor packing, global
tensor materialization, and return-path normalization.
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
standard MLIR `BufferizableOpInterface` machinery
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
Converts tensor-semantics PIM IR into memref-semantics PIM IR using MLIR's
bufferization interfaces.
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
Conservatively reuses same-typed local memref allocations inside PIM cores
after bufferization and before code generation.
5. **Static memory coalescing**
(`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
Reuses compatible local memref allocations inside PIM cores before codegen.
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
- `HostConstantFolding` — folds host-side constants.
- `MaterializeHostConstantsPass` - materializes the remaining host
constants for emission.
- `VerificationPass` — checks invariants before emission.
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
and `pim-simulator`.
6. **PIM code generation** (`src/PIM/Pass/PimCodegen` and
`src/PIM/Compiler`).
Folds host constants, materializes remaining host constants, verifies PIM IR,
emits `.pim` core files, writes weights, and writes `memory.bin` /
`config.json`.
Supporting pieces:
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
core count, DCP window, experimental conv impl, concat error handling, …)
and `PimCodeGen` entry points.
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
and the `PIMPasses.h` registry used by `PimAccelerator`.
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
- `src/PIM/Common` - shared IR, filesystem, diagnostics, reports, and utility
helpers.
- `src/PIM/Compiler` - PIM compiler options, memory/address planning, binary
instruction format, artifact writing, weight emission, and codegen entry
points.
- `src/PIM/Conversion/SpatialToGraphviz` - optional Spatial graphviz conversion
pass.
- `src/PIM/Pass` - pass registration and auxiliary passes.
- `src/PIM/PimAccelerator.{cpp,hpp}` - ONNX-MLIR accelerator entry point.
## Key compiler options
Pass these on the `onnx-mlir` command line when compiling for PIM:
Pass these to `onnx-mlir` when compiling for PIM:
- `--maccel=PIM` - select the PIM accelerator.
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
stop the pipeline at the requested stage (default: `EmitPimCodegen`).
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
run only the codegen tail.
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
per-core count.
- `--core-count=<N>` — number of cores. Required for PIM compilation.
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
merge-compute-nodes pass (default: `peft`).
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
- `--use-experimental-conv-impl` - alternative convolution lowering.
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
- `--maccel=PIM` - select the PIM accelerator.
- `--EmitSpatial`, `--EmitPim`, `--EmitPimBufferized`,
`--EmitPimCodegen` - stop the PIM pipeline at the requested stage. The PIM
default is `--EmitPimCodegen`.
- `--core-count=<N>` - required positive core count for PIM compilation.
- `--crossbar-size=<N>` - crossbar width/height. Default in code is `2`.
- `--crossbar-count=<N>` - crossbars per core. Default in code is `256`.
- `--pim-merge-scheduler=peft` - merge scheduler. `peft` is the only accepted
value in the current code.
- `--pim-only-codegen` - assume input is already bufferized PIM IR and only run
the codegen tail.
- `--pim-emit-json` - also emit `core_*.json` instruction files alongside
`core_*.pim`.
- `--use-experimental-conv-impl` - use the alternate convolution lowering.
- `--ignore-concat-error` - soft-fail a ConcatOp corner case.
Example:
```bash
./build_release/Release/bin/onnx-mlir model.onnx -o /tmp/raptor/model \
--maccel=PIM --EmitPimCodegen \
--crossbar-size=2048 --crossbar-count=256 --core-count=1000
```
This writes PIM artifacts under `/tmp/raptor/pim/`.
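When the compiler is driven from a script, the same invocation can be assembled programmatically. The sketch below is an illustration only: `pim_compile_args` is a hypothetical helper, not something the repository provides, and it uses only the flags listed above. It rejects a non-positive core count up front, since `--core-count` is required for PIM compilation.

```python
from typing import List, Optional


def pim_compile_args(
    compiler: str,
    model: str,
    output: str,
    core_count: int,
    crossbar_size: Optional[int] = None,
    crossbar_count: Optional[int] = None,
    emit_json: bool = False,
) -> List[str]:
    """Build an onnx-mlir argument list for PIM codegen (hypothetical helper)."""
    if core_count <= 0:
        raise ValueError("--core-count must be a positive integer")
    args = [
        compiler, model, "-o", output,
        "--maccel=PIM", "--EmitPimCodegen",
        f"--core-count={core_count}",
    ]
    # Optional flags are omitted so the compiler falls back to its defaults.
    if crossbar_size is not None:
        args.append(f"--crossbar-size={crossbar_size}")
    if crossbar_count is not None:
        args.append(f"--crossbar-count={crossbar_count}")
    if emit_json:
        args.append("--pim-emit-json")
    return args
```

The returned list can be passed directly to `subprocess.run(args, check=True)`.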
## Validation
Functional validation lives in `validation/` and drives the Rust
`pim-simulator` to compare Raptor's output against a reference.
Functional validation lives in `validation/`. It compiles ONNX models, builds a
native ONNX-MLIR reference runner, generates random inputs, runs Raptor, runs
the Rust PIM simulator, and compares outputs.
Per-operation validation (from `validation/`):
Python dependencies used by the validation scripts are `numpy`, `onnx`, and
`colorama`. The simulator requires the Rust toolchain.
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
Per-operation validation from the repository root:
```bash
python3 validation/validate.py \
--raptor-path build_release/Release/bin/onnx-mlir \
--onnx-include-dir onnx-mlir/include \
--core-count 1000
```
End-to-end network validation (example: first 4 layers of YOLOv11n):
Validate one network or a subset by pointing `--operations-dir` at any directory
containing `.onnx` files:
```
validate.py \
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
--onnx-include-dir ../onnx-mlir/include \
--operations-dir ./networks/yolo11n/depth_04 \
```bash
python3 validation/validate.py \
--raptor-path build_release/Release/bin/onnx-mlir \
--onnx-include-dir onnx-mlir/include \
--operations-dir validation/networks/yolo11n/depth_04 \
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
```
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
Available operations under `validation/operations/`: `add`, `conv`, `div`,
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
`sigmoid`, `softmax`, `split`.
Useful validation options:
- `--simulator-dir <path>` - override the auto-detected
`backend-simulators/pim/pim-simulator` path.
- `--threshold <float>` - maximum allowed per-element output difference.
- `--seed <int>` - RNG seed for generated inputs.
- `--command-timeout-seconds <float>` - timeout for compiler, runner, and
simulator subprocesses.
- `--verbose` - print subprocess logs and average PIM pass timings.
- `--clean` - remove generated validation artifacts and exit.
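To make the meaning of `--threshold` concrete, here is a minimal sketch of a per-element comparison. It is an assumption about how such a check could look, not the code in `validate.py`: the function name `within_threshold`, the use of absolute difference, and the inclusive `<=` bound are all illustrative choices.

```python
from typing import Sequence


def within_threshold(
    reference: Sequence[float],
    simulated: Sequence[float],
    threshold: float,
) -> bool:
    """True if both outputs have the same length and no element differs
    by more than `threshold` in absolute value (assumed semantics)."""
    if len(reference) != len(simulated):
        return False
    return all(abs(r - s) <= threshold for r, s in zip(reference, simulated))
```

A single out-of-range element is enough to fail the comparison, which matches the "maximum allowed per-element difference" wording.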
## Rebuilding
Each validation run writes artifacts in the model workspace, for example under
`validation/operations/gemm/small/`:
- `inputs/` - generated input CSV files.
- `outputs/` - native ONNX-MLIR reference outputs.
- `raptor/` - compiler artifacts, including `*.onnx.mlir`, dialect dumps under
`dialects/`, reports under `reports/`, and final PIM artifacts under `pim/`.
- `runner/` - generated reference runner source, build tree, and shared library.
- `simulation/out.bin` - raw simulator output used for comparison.
Release build (fast):
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
`pim2_coalesced.mlir`, `pim3_folded.mlir`, and
`pim4_materialized.mlir` when an output directory is available.
```
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
To rerun the simulator manually with tracing after validation has produced a
`raptor/pim/` directory:
```bash
cd backend-simulators/pim/pim-simulator
cargo run --no-default-features --features tracing --release \
--package pim-simulator --bin pim-simulator -- \
-f /path/to/workspace/raptor/pim \
-o /path/to/workspace/simulation/out.bin \
-d <addr0>,<size0>,<addr1>,<size1>,...
```
A slower debug build is also available — configure it the same way but with
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
With `--features tracing`, the simulator writes per-core traces as
`TraceCore0`, `TraceCore1`, ... next to `out.bin`. The validator normally
computes the `-d` ranges from `raptor/pim/config.json` and model output shapes.
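For manual runs, the `-d` argument has to be written by hand as a flat `addr,size` list. The sketch below shows one way to assemble it. Both helper names are hypothetical, and treating each size as a byte count derived from the output shape and element width is an assumption; check `config.json` and the validator for the actual units.

```python
import math
from typing import Iterable, Sequence, Tuple


def byte_size(shape: Sequence[int], element_bytes: int) -> int:
    """Size of a dense tensor, assuming sizes are expressed in bytes."""
    return math.prod(shape) * element_bytes


def format_dump_ranges(ranges: Iterable[Tuple[int, int]]) -> str:
    """Join (address, size) pairs into the addr0,size0,addr1,size1,... form."""
    return ",".join(f"{addr},{size}" for addr, size in ranges)
```

For example, `format_dump_ranges([(0, byte_size((1, 1000), 4))])` produces a string suitable for passing after `-d`.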
Available validation networks under `validation/networks/`: `vgg16`,
`yolo11n`, `yolo11nv2`.
Available operation suites under `validation/operations/`: `add`, `concat`,
`conv`, `div`, `gather`, `gemm`, `gemv`, `matmul`, `mul`, `pool`,
`reduce_mean`, `relu`, `reshape`, `resize`, `sigmoid`, `softmax`, `split`.
Generated operation tests can be regenerated with:
```bash
python3 validation/operations/gen_tests.py
```
## Build
Initialize submodules first:
```bash
git submodule update --init --recursive
```
The project follows ONNX-MLIR's build requirements. The CI workflow documents
the currently used versions and setup:
- CMake 4.3.0 in CI,
- LLVM/MLIR checked out under `onnx-mlir/llvm-project`,
- Protobuf `v34.0`,
- Rust stable for `pim-simulator`,
- Python packages `numpy`, `onnx`, `colorama` for validation.
### Protobuf
Use the following commands to install protobuf:
```
Install Protobuf if your system does not already provide a compatible version:
```bash
git clone --depth 1 --branch v34.0 https://github.com/protocolbuffers/protobuf
cd protobuf
mkdir build
cd build
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
ninja
sudo ninja install
cmake -S protobuf -B protobuf/build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-Dprotobuf_BUILD_TESTS=OFF
cmake --build protobuf/build
sudo cmake --install protobuf/build
```
You can now remove the protobuf repo directory with:
```
cd ../..
You can then remove the temporary checkout:
```bash
rm -rf protobuf
```
### MLIR
Follow the ONNX-MLIR instructions in
`onnx-mlir/docs/BuildOnLinuxOSX.md` to build LLVM/MLIR. The local Raptor build
expects `MLIR_DIR` to point at the MLIR CMake package, for example:
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
```
If your LLVM build directory is named `build` instead of `build_release`, adjust
the path accordingly.
### Raptor
Configure a release build:
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
cmake -S . -B build_release -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
```
Configure a debug build similarly:
```bash
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_debug/lib/cmake/mlir
cmake -S . -B build_debug -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR}
```
For debug development, using `mold` can reduce link time and memory use:
```bash
cmake -S . -B build_debug -G Ninja \
-DCMAKE_BUILD_TYPE=Debug \
-DONNX_MLIR_ACCELERATORS=PIM \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DMLIR_DIR=${MLIR_DIR} \
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
```
Build the compiler with CMake:
```bash
cmake --build ./build_release
cmake --build ./build_debug
```
Do not invoke `ninja` directly for this project; use `cmake --build` so CMake's
configuration and generated shims stay consistent.
If a build fails because Protobuf headers are missing fixed-width integer
definitions, patch the affected Protobuf-generated files by adding
`#include <cstdint>`.
## Tests
The Rust simulator has its own tests:
```bash
cd backend-simulators/pim/pim-simulator
cargo test
```
## Repository Layout
- `src/PIM/` - PIM accelerator implementation.
- `test/PIM/` - PIM C++ unit tests.
- `validation/` - functional validation scripts, ONNX operation tests, network
slices, and pimsim config generation.
- `backend-simulators/pim/pim-simulator/` - in-tree Rust functional simulator.
- `backend-simulators/pim/pimsim-nn/` - performance simulator submodule.
- `pimcomp_utils/` - local comparison helpers for PIMCOMP-NN.
- `.github/actions/` and `.github/workflows/validate_operations.yml` - CI setup
for MLIR/Protobuf caching, building Raptor, and validation.
@@ -43,7 +43,7 @@ struct Args {
/// Comma separated list of (address,size) for memory output dump
#[arg(short, long, value_delimiter = ',', num_args = 1.., value_name = "ADDR,SIZE")]
dump: Vec<i32>,
dump: Vec<usize>,
}
fn main() -> Result<()> {
@@ -168,7 +168,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
}
fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
let dumps: Vec<(i32, i32)> = args
let dumps: Vec<(usize, usize)> = args
.dump
.chunks_exact(2)
.map(|chunk| (chunk[0], chunk[1]))
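The `dump_memory` hunk above groups the flat `--dump` values into `(address, size)` pairs with `chunks_exact(2)`. One property worth keeping in mind is that `chunks_exact` leaves a trailing unpaired value in its remainder instead of reporting it. The sketch below shows one way to surface that case as an error; `pair_dump_args` is a hypothetical helper name and is not part of the simulator.

```rust
/// Groups a flat `ADDR,SIZE,ADDR,SIZE,...` list into `(address, size)` pairs.
/// Hypothetical helper: returns an error for an odd-length list instead of
/// silently ignoring the last value, which `chunks_exact(2)` alone would do.
pub fn pair_dump_args(flat: &[usize]) -> Result<Vec<(usize, usize)>, String> {
    let chunks = flat.chunks_exact(2);
    // `remainder()` holds the elements that did not fit into a full chunk.
    if !chunks.remainder().is_empty() {
        return Err(format!(
            "expected an even number of dump values, got {}",
            flat.len()
        ));
    }
    Ok(chunks.map(|chunk| (chunk[0], chunk[1])).collect())
}
```

With this shape, `0,16,64,8` yields two ranges, while `0,16,64` is rejected rather than dumping only the first range.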
@@ -1,3 +1,4 @@
use crate::utility::AddressArg;
use std::{collections::HashMap, fmt::Debug};
use anyhow::{Context, Result, ensure};
@@ -9,6 +10,7 @@ use crate::{
pub mod crossbar;
#[derive(Debug, Clone)]
pub struct CPU<'a> {
cores: Box<[Core<'a>]>,
@@ -91,30 +93,26 @@ impl<'a> Core<'a> {
self.memory.execute_load()
}
pub fn execute_store<T>(&mut self, address: impl TryToUsize, element: &[T]) -> Result<()>
pub fn execute_store<T>(&mut self, address: impl AddressArg, element: &[T]) -> Result<()>
where
T: MemoryStorable,
{
let address = address.try_into().context("address can not be negative")?;
let address = address.to_address_usize()?;
self.memory.execute_store(address, element)
}
pub fn reserve_load(
&mut self,
address: impl TryToUsize,
address: impl AddressArg,
size: impl TryToUsize,
) -> Result<&mut CoreMemory> {
let address = address.try_into().context("address can not be negative")?;
let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?;
self.memory.reserve_load(address, size)
}
pub fn set_register(&mut self, index: impl TryToUsize, value: i32) {
let index = index.try_into().expect("index can not be negative");
assert!(
value >= 0,
"Register cannot be negative if happens remove this and go check where it's used as usize"
);
self.registers[index] = value;
}
@@ -123,11 +121,11 @@ impl<'a> Core<'a> {
self.registers[index]
}
pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
pub fn load<T>(&mut self, address: impl AddressArg, size: impl TryToUsize) -> Result<Vec<&[T]>>
where
T: MemoryStorable,
{
let address = address.try_into().context("address can not be negative")?;
let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?;
self.memory.load(address, size)
}
@@ -141,8 +139,8 @@ impl<'a> Core<'a> {
(memory, crossbars)
}
pub fn memset(&mut self, address: impl TryToUsize, size: impl TryToUsize, val: u8) -> Result<()> {
let address = address.try_into().context("address can not be negative")?;
pub fn memset(&mut self, address: impl AddressArg, size: impl TryToUsize, val: u8) -> Result<()> {
let address = address.to_address_usize()?;
let size = size.try_into().context("size can not be negative")?;
self.memory.memset(address, size, val)
}
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
if in_path.contains(&waiting_for) {
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
let cycle = &path[cycle_start..];
let format_core = |core: &i32| (core - 1).to_string();
let cycle_str = cycle
.iter()
.map(|c| c.to_string())
.map(format_core)
.collect::<Vec<_>>()
.join(" -> ");
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
.copied()
.chain(std::iter::once(waiting_for))
.collect::<Vec<_>>();
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
let states_msg = cycle
.iter()
.filter_map(|core| {
states.get(core).map(|state| match state {
CoreState::SendingTo(target, size) => {
format!("core {} send {}B -> {}", core, size, target)
format!("core {} send {}B -> {}", core - 1, size, target - 1)
}
CoreState::ReceivingFrom(source, size) => {
format!("core {} recv {}B <- {}", core, size, source)
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
}
CoreState::Working => format!("core {} working", core),
CoreState::Halted => format!("core {} halted", core),
CoreState::Working => format!("core {} working", core - 1),
CoreState::Halted => format!("core {} halted", core - 1),
})
})
.collect::<Vec<_>>()
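The `detect_deadlock` hunk above subtracts one from every core id before printing it, which suggests that the ids stored internally are offset by one from the ids shown to the user. That reading is an assumption drawn from the diff alone. A minimal sketch of the same formatting, with `format_cycle` as a hypothetical standalone function:

```rust
/// Formats a wait cycle for display, shifting each internal core id down by one.
/// `cycle` lists the cores on the cycle in order; `waiting_for` closes the loop.
/// Hypothetical helper that mirrors the `core - 1` adjustment in the diff.
pub fn format_cycle(cycle: &[i32], waiting_for: i32) -> String {
    let cycle_str = cycle
        .iter()
        .map(|core| (core - 1).to_string())
        .collect::<Vec<_>>()
        .join(" -> ");
    format!("{} -> {}", cycle_str, waiting_for - 1)
}
```

Keeping the shift in one function avoids the situation where one message prints shifted ids and another prints raw ones.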
@@ -1,7 +1,45 @@
use anyhow::{Result,Context};
use std::{fmt::Debug, mem::transmute};
use crate::memory_manager::type_traits::TryToUsize;
pub trait AddressArg {
fn to_address_usize(self) -> Result<usize>;
}
impl AddressArg for usize {
fn to_address_usize(self) -> Result<usize> {
Ok(self)
}
}
impl AddressArg for u32 {
fn to_address_usize(self) -> Result<usize> {
Ok(self as usize)
}
}
impl AddressArg for u64 {
fn to_address_usize(self) -> Result<usize> {
usize::try_from(self).context("address does not fit in usize")
}
}
impl AddressArg for i32 {
fn to_address_usize(self) -> Result<usize> {
Ok(self as u32 as usize)
}
}
impl AddressArg for i64 {
fn to_address_usize(self) -> Result<usize> {
usize::try_from(self).context("address can not be negative")
}
}
fn address_to_usize(address: i32) -> usize {
address as u32 as usize
}
fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i32) -> usize{
assert!(offset_select == 1 || offset_select == 2 || offset_select == 4 || offset_value == 0, "offset_select not a bit field");
@@ -14,21 +52,21 @@ fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i
}
pub fn add_offset_rd(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
pub fn add_offset_rd(address: i32, offset_select : i32, offset_value : i32) -> usize
{
let address = address.try_into().expect("address can not be negative");
let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 4)
}
pub fn add_offset_r1(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
pub fn add_offset_r1(address: i32, offset_select : i32, offset_value : i32) -> usize
{
let address = address.try_into().expect("address can not be negative");
let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 1)
}
pub fn add_offset_r2(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
pub fn add_offset_r2(address: i32, offset_select : i32, offset_value : i32) -> usize
{
let address = address.try_into().expect("address can not be negative");
let address = address_to_usize(address);
add_offset_impl(address, offset_select, offset_value, 2)
}
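The `AddressArg` implementations above treat `i32` differently from `i64`: an `i32` is reinterpreted bit for bit (`as u32 as usize`), while an `i64` goes through a checked conversion that rejects negative values. The sketch below contrasts the two behaviours on plain functions; `reinterpret_address` and `checked_address` are illustrative names, not simulator APIs.

```rust
use std::convert::TryFrom;

/// Reinterprets the bits of a signed 32-bit value as an unsigned address,
/// mirroring the `address as u32 as usize` cast used for `i32` in the diff.
pub fn reinterpret_address(address: i32) -> usize {
    address as u32 as usize
}

/// Checked alternative: rejects negative values instead of wrapping them.
pub fn checked_address(address: i32) -> Option<usize> {
    usize::try_from(address).ok()
}
```

Under the reinterpreting cast, `-1` becomes `0xFFFF_FFFF` rather than an error, so an `i32` register holding an address above `i32::MAX` stays usable. The cost is that a genuinely negative value is no longer caught at this layer.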
+53
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
set(PIM_GENERATED_PATH_SHIM_TARGET "")
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
function(add_pim_generated_path_shim relative_path)
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
add_custom_command(
OUTPUT "${shim_file}"
DEPENDS "${real_file}"
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
VERBATIM
)
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
endfunction()
file(GLOB_RECURSE pim_generated_path_scan_sources
CONFIGURE_DEPENDS
"${PIM_SRC_ROOT}/*.cpp"
"${PIM_SRC_ROOT}/*.hpp"
)
set(pim_generated_path_shims)
foreach (source_file IN LISTS pim_generated_path_scan_sources)
file(READ "${source_file}" source_contents)
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
foreach (inc_match IN LISTS source_inc_matches)
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
list(APPEND pim_generated_path_shims "${relative_inc_path}")
endforeach ()
endforeach ()
list(REMOVE_DUPLICATES pim_generated_path_shims)
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
add_pim_generated_path_shim("${relative_inc_path}")
endforeach ()
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
endif ()
set(PIM_PUBLIC_INCLUDE_DIRS
${ONNX_MLIR_SRC_ROOT}/include
${ONNX_MLIR_SRC_ROOT}
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
function(add_pim_library name)
add_onnx_mlir_library(${name} STATIC ${ARGN})
if (PIM_GENERATED_PATH_SHIM_TARGET)
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
endif ()
endfunction()
add_subdirectory(Dialect)
+4
@@ -1,5 +1,8 @@
add_pim_library(OMPimCommon
IR/AffineUtils.cpp
IR/AddressAnalysis.cpp
IR/BatchCoreUtils.cpp
IR/ConstantUtils.cpp
IR/CoreBlockUtils.cpp
IR/EntryPointUtils.cpp
IR/ShapeUtils.cpp
@@ -16,6 +19,7 @@ add_pim_library(OMPimCommon
${PIM_PUBLIC_INCLUDE_DIRS}
LINK_LIBS PUBLIC
MLIRLinalgDialect
onnx
SpatialOps
PimOps
+547 -7
@@ -1,7 +1,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include <limits>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -28,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
return value;
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
template <typename... Args>
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
}
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -55,6 +67,288 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
}
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (failed(type.getStridesAndOffset(strides, offset)))
return mlir::failure();
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
return mlir::failure();
return strides;
}
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
const StaticValueKnowledge* knowledge) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
if (indices.size() != static_cast<size_t>(globalType.getRank()))
return mlir::failure();
auto strides = computeRowMajorStrides(globalType.getShape());
int64_t linearIndex = linearizeIndex(indices, strides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
switch (predicate) {
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
}
llvm_unreachable("unknown cmpi predicate");
}
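`evaluateCmpPredicate` above evaluates the signed predicates on the `int64_t` values directly and the unsigned ones after casting both operands to `uint64_t`. The difference only shows up when an operand is negative. A small Rust restatement of that idea, with hypothetical function names:

```rust
/// Signed comparison: operands are compared as two's-complement integers.
pub fn signed_less_than(lhs: i64, rhs: i64) -> bool {
    lhs < rhs
}

/// Unsigned comparison: the same bits are compared as unsigned integers,
/// so any negative value ranks above every non-negative one.
pub fn unsigned_less_than(lhs: i64, rhs: i64) -> bool {
    (lhs as u64) < (rhs as u64)
}
```

For `-1` and `1`, the signed form reports `-1 < 1`, while the unsigned form sees `-1` as `u64::MAX` and reports the opposite.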
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
const StaticValueKnowledge& knowledge) {
if (!expr.node)
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
case CompiledIndexExprNode::Kind::Symbol: {
auto value = resolveAlias(expr.node->symbol, &knowledge);
auto iter = knowledge.indexValues.find(value);
if (iter != knowledge.indexValues.end())
return iter->second;
return mlir::failure();
}
case CompiledIndexExprNode::Kind::Add:
case CompiledIndexExprNode::Kind::Sub:
case CompiledIndexExprNode::Kind::Mul:
case CompiledIndexExprNode::Kind::DivUI:
case CompiledIndexExprNode::Kind::DivSI:
case CompiledIndexExprNode::Kind::RemUI:
case CompiledIndexExprNode::Kind::RemSI:
case CompiledIndexExprNode::Kind::MinUI:
case CompiledIndexExprNode::Kind::CmpI: {
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
switch (expr.node->kind) {
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
case CompiledIndexExprNode::Kind::DivUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::DivSI:
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
return mlir::failure();
return *lhs / *rhs;
case CompiledIndexExprNode::Kind::RemUI:
if (*rhs == 0)
return mlir::failure();
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
case CompiledIndexExprNode::Kind::RemSI:
if (*rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
case CompiledIndexExprNode::Kind::MinUI:
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
default: llvm_unreachable("unexpected binary compiled index kind");
}
}
case CompiledIndexExprNode::Kind::Select: {
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
if (failed(condition))
return mlir::failure();
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
}
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
if (!denseAttr || !globalType)
return mlir::failure();
llvm::SmallVector<int64_t> indices;
indices.reserve(expr.node->operands.size());
for (const CompiledIndexExpr& operand : expr.node->operands) {
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
if (failed(resolvedIndex))
return mlir::failure();
indices.push_back(*resolvedIndex);
}
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
return mlir::failure();
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
}
}
llvm_unreachable("unknown compiled index kind");
}
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
if (!getGlobalOp)
return mlir::failure();
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
if (!denseAttr || !globalType || !globalType.hasStaticShape())
return mlir::failure();
auto elementType = denseAttr.getElementType();
if (!elementType.isIndex() && !elementType.isInteger())
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
expr.globalOp = globalOp;
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
expr.operands.reserve(loadOp.getIndices().size());
for (mlir::Value index : loadOp.getIndices()) {
auto compiledIndex = compileIndexValueImpl(index);
if (failed(compiledIndex))
return mlir::failure();
expr.operands.push_back(*compiledIndex);
}
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = integerAttr.getInt();
return makeCompiledIndexExpr(std::move(expr));
}
}
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
auto lhs = compileIndexValueImpl(lhsValue);
auto rhs = compileIndexValueImpl(rhsValue);
if (failed(lhs) || failed(rhs))
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {*lhs, *rhs};
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
};
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
return compileIndexValueImpl(indexCastOp.getIn());
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
if (failed(expr))
return mlir::failure();
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
exprNode->predicate = cmpOp.getPredicate();
return CompiledIndexExpr(exprNode);
}
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
auto lhs = compileIndexValueImpl(maxOp.getLhs());
auto rhs = compileIndexValueImpl(maxOp.getRhs());
if (failed(lhs) || failed(rhs))
return mlir::failure();
CompiledIndexExprNode cmpExpr;
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
cmpExpr.operands = {*lhs, *rhs};
CompiledIndexExprNode selectExpr;
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
return makeCompiledIndexExpr(std::move(selectExpr));
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = compileIndexValueImpl(selectOp.getCondition());
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
if (failed(condition) || failed(trueValue) || failed(falseValue))
return mlir::failure();
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Select;
expr.operands = {*condition, *trueValue, *falseValue};
return makeCompiledIndexExpr(std::move(expr));
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return compileConstantGlobalLoad(loadOp);
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Symbol;
expr.symbol = value;
return makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
value = resolveAlias(value, knowledge);
@@ -110,6 +404,16 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
}
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return mlir::failure();
return *lhs / *rhs;
}
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
@@ -126,6 +430,34 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
}
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs) || *rhs == 0)
return mlir::failure();
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
return 0;
return *lhs % *rhs;
}
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
if (failed(lhs) || failed(rhs))
return mlir::failure();
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
}
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
if (failed(condition))
return mlir::failure();
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
}
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
return resolveConstantGlobalLoad(loadOp, knowledge);
return mlir::failure();
}
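The `DivSIOp` and `RemSIOp` cases above guard two inputs: a zero divisor, and `INT64_MIN` with divisor `-1`, where the quotient overflows. Division fails in both cases. The remainder fails on zero and returns `0` for the overflow case. A Rust sketch of the same rules using standard integer methods; the function names are illustrative:

```rust
/// Signed division that fails on a zero divisor and on `i64::MIN / -1`.
pub fn eval_div_si(lhs: i64, rhs: i64) -> Option<i64> {
    // `checked_div` returns `None` for both of those cases.
    lhs.checked_div(rhs)
}

/// Signed remainder that fails on a zero divisor and yields 0 for
/// `i64::MIN % -1`, matching the explicit special case in the diff.
pub fn eval_rem_si(lhs: i64, rhs: i64) -> Option<i64> {
    if rhs == 0 {
        return None;
    }
    // `wrapping_rem` returns 0 for `i64::MIN % -1` instead of overflowing.
    Some(lhs.wrapping_rem(rhs))
}
```

Both operations truncate toward zero, as C++ integer division does, so `-7 / 2` is `-3` and `-7 % 2` is `-1`.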
@@ -217,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
return mlir::failure();
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
value = resolveAlias(subviewOp.getSource(), knowledge);
continue;
}
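The subview hunk above turns per-dimension element offsets into a byte offset: it linearizes the offsets against the source strides and scales the result by the element size in bytes. The sketch below works through that arithmetic for a row-major layout. `row_major_strides` and `byte_offset` are illustrative stand-ins; the bodies of the repository's `computeRowMajorStrides` and `linearizeIndex` are not shown in this diff.

```rust
/// Strides (in elements) of a densely packed row-major array of `shape`.
pub fn row_major_strides(shape: &[i64]) -> Vec<i64> {
    let mut strides = vec![1; shape.len()];
    // Walk from the second-to-last dimension outwards.
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Byte offset of the element at `offsets`, given strides in elements.
pub fn byte_offset(offsets: &[i64], strides: &[i64], element_bytes: i64) -> i64 {
    let linear: i64 = offsets.iter().zip(strides).map(|(o, s)| o * s).sum();
    linear * element_bytes
}
```

For a `2x3x4` array of 4-byte elements the strides are `[12, 4, 1]`, so offsets `[1, 2, 3]` land at element `23`, which is byte `92`. Taking the strides from the memref type rather than recomputing them, as the new code does, gives the same result for this dense layout. It differs only when the source type carries a non-default layout.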
@@ -243,17 +577,206 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
}
}
} // namespace
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
int64_t constantByteOffset = 0;
CompiledIndexExpr byteOffsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
while (true) {
if (mlir::isa<mlir::BlockArgument>(value))
return CompiledAddressExpr {value, byteOffsetExpr};
mlir::Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return mlir::failure();
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
if (!tiedOperand)
return mlir::failure();
value = tiedOperand->get();
continue;
}
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
auto result = mlir::dyn_cast<mlir::OpResult>(value);
if (!result)
return mlir::failure();
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
continue;
}
}
value = yieldedValue;
continue;
}
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
return mlir::failure();
llvm::SmallVector<int64_t> staticSizes;
staticSizes.reserve(subviewOp.getMixedSizes().size());
llvm::SmallVector<int64_t> staticStrides;
staticStrides.reserve(subviewOp.getMixedStrides().size());
llvm::SmallVector<int64_t> staticOffsets;
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
bool hasOnlyStaticOffsets = true;
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
else
hasOnlyStaticOffsets = false;
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
if (!attr)
return mlir::failure();
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
if (!attr)
return mlir::failure();
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
}
if (!isContiguousSubviewWithDynamicOffsets(
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
return mlir::failure();
}
if (hasOnlyStaticOffsets) {
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
return mlir::failure();
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
constantByteOffset +=
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
}
else {
auto sourceStrides = getStaticMemRefStrides(sourceType);
if (failed(sourceStrides))
return mlir::failure();
CompiledIndexExpr offsetExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = 0;
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
CompiledIndexExpr operandExpr;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
* getElementTypeSizeInBytes(subviewType.getElementType());
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
else {
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
if (failed(compiledOffset))
return mlir::failure();
CompiledIndexExpr scaleExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
scaleExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Mul;
expr.operands = {*compiledOffset, scaleExpr};
operandExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {offsetExpr, operandExpr};
offsetExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, offsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
constantByteOffset = 0;
}
value = subviewOp.getSource();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
value = castOp.getSource();
continue;
}
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
value = collapseOp.getSrc();
continue;
}
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
value = expandOp.getSrc();
continue;
}
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
if (constantByteOffset != 0) {
CompiledIndexExpr constantExpr;
{
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constantByteOffset;
constantExpr = makeCompiledIndexExpr(std::move(expr));
}
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
byteOffsetExpr = constantExpr;
else {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Add;
expr.operands = {constantExpr, byteOffsetExpr};
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
}
}
return CompiledAddressExpr {value, byteOffsetExpr};
}
return mlir::failure();
}
}
} // namespace
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
return resolveIndexValueImpl(value, &knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
return resolveContiguousAddressImpl(value, nullptr);
}
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge) {
@@ -264,4 +787,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
return resolveLoopCarriedAliasImpl(value, &knowledge);
}
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
return compileContiguousAddressExprImpl(value);
}
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
return evaluateCompiledIndexExpr(*this, knowledge);
}
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const {
(void) lane;
auto resolvedOffset = byteOffset.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
return ResolvedContiguousAddress {base, *resolvedOffset};
}
} // namespace onnx_mlir
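
The static-offset branch of the subview handling above folds a subview into one constant byte offset: the per-dimension offsets are linearized against the source strides, then scaled by the element size. A minimal standalone sketch of that arithmetic follows. It assumes row-major strides and uses `std::vector` in place of the MLIR types; `rowMajorStrides` and `subviewByteOffset` are illustrative names, not the project's helpers.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major strides measured in elements: the innermost dimension has stride 1.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Byte offset of the first element of a subview inside its source buffer:
// sum(offset[i] * stride[i]) * elementSizeInBytes.
int64_t subviewByteOffset(const std::vector<int64_t>& offsets,
                          const std::vector<int64_t>& strides,
                          int64_t elementSizeInBytes) {
  int64_t linear = 0;
  for (std::size_t i = 0; i < offsets.size() && i < strides.size(); ++i)
    linear += offsets[i] * strides[i];
  return linear * elementSizeInBytes;
}
```

For a 4x8x16 buffer of 4-byte elements, the subview starting at offsets (1, 2, 0) begins (1*128 + 2*16 + 0) * 4 bytes into the source.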
+55 -4
@@ -1,10 +1,14 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include <memory>
#include <optional>
namespace onnx_mlir {
/// Describes a value as a base addressable object plus a statically known
@@ -23,21 +27,68 @@ struct StaticValueKnowledge {
StaticValueKnowledge() {}
};
struct CompiledIndexExprNode;
struct CompiledIndexExpr {
std::shared_ptr<CompiledIndexExprNode> node;
CompiledIndexExpr() = default;
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
: node(std::move(node)) {}
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
};
struct CompiledIndexExprNode {
enum class Kind {
Constant,
Symbol,
Add,
Sub,
Mul,
DivUI,
DivSI,
RemUI,
RemSI,
MinUI,
CmpI,
Select,
ConstantGlobalLoad
};
Kind kind = Kind::Constant;
int64_t constant = 0;
mlir::Value symbol;
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t, 4> globalStrides;
llvm::SmallVector<CompiledIndexExpr, 4> operands;
};
struct CompiledAddressExpr {
mlir::Value base;
CompiledIndexExpr byteOffset;
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
std::optional<unsigned> lane) const;
};
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
/// Resolves a value to contiguous backing storage when that storage can be
/// proven statically from aliases, DPS ties, casts, and subviews.
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
const StaticValueKnowledge& knowledge);
const StaticValueKnowledge& knowledge = {});
/// Statically evaluates index-like SSA values, including simple integer
/// arithmetic and loop facts recorded in `knowledge`.
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
/// Follows alias, view, and DPS chains to recover the backing value of a
/// loop-carried memref/result.
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
} // namespace onnx_mlir
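
The header above separates compiling an address expression once from evaluating it many times against different `StaticValueKnowledge`. A minimal sketch of that pattern, reduced to four node kinds, is given below. It is not the project's implementation: a `std::string` stands in for `mlir::Value`, a `std::map` stands in for `StaticValueKnowledge`, and every name in it is hypothetical.

```cpp
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

struct ExprNode;
using Expr = std::shared_ptr<ExprNode>;

struct ExprNode {
  enum class Kind { Constant, Symbol, Add, Mul };
  Kind kind = Kind::Constant;
  int64_t constant = 0;
  std::string symbol;          // stands in for an mlir::Value
  std::vector<Expr> operands;  // two operands for Add and Mul
};

Expr makeConstant(int64_t value) {
  auto node = std::make_shared<ExprNode>();
  node->constant = value;
  return node;
}

Expr makeSymbol(std::string name) {
  auto node = std::make_shared<ExprNode>();
  node->kind = ExprNode::Kind::Symbol;
  node->symbol = std::move(name);
  return node;
}

Expr makeBinary(ExprNode::Kind kind, Expr lhs, Expr rhs) {
  auto node = std::make_shared<ExprNode>();
  node->kind = kind;
  node->operands = {std::move(lhs), std::move(rhs)};
  return node;
}

// Evaluates the tree; returns nullopt when a symbol has no known value.
std::optional<int64_t> evaluate(const Expr& expr,
                                const std::map<std::string, int64_t>& knowledge) {
  switch (expr->kind) {
  case ExprNode::Kind::Constant:
    return expr->constant;
  case ExprNode::Kind::Symbol: {
    auto it = knowledge.find(expr->symbol);
    if (it == knowledge.end())
      return std::nullopt;
    return it->second;
  }
  case ExprNode::Kind::Add:
  case ExprNode::Kind::Mul: {
    auto lhs = evaluate(expr->operands[0], knowledge);
    auto rhs = evaluate(expr->operands[1], knowledge);
    if (!lhs || !rhs)
      return std::nullopt;
    return expr->kind == ExprNode::Kind::Add ? *lhs + *rhs : *lhs * *rhs;
  }
  }
  return std::nullopt;
}
```

Building `16 + i * 64` once and evaluating it per loop iteration only requires swapping the value bound to `i` in the knowledge map; the tree itself is shared and never rebuilt.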
+182
@@ -0,0 +1,182 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "AffineUtils.hpp"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
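// Floor division (rounds toward negative infinity). Only positive divisors are
// supported; any other divisor yields failure.
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {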
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs < 0)
--quotient;
return quotient;
}
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
if (rhs <= 0)
return failure();
int64_t quotient = lhs / rhs;
int64_t remainder = lhs % rhs;
if (remainder != 0 && lhs > 0)
++quotient;
return quotient;
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");
SmallVector<Attribute> operandConstants;
operandConstants.reserve(operands.size());
for (Value operand : operands) {
std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
if (!constantValue)
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
}
SmallVector<Attribute> foldedResults;
if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1)
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt());
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
}
Value createOrFoldAffineApply(
RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
}
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
if (multiplier == 0)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
if (multiplier == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
}
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.mod divisor");
if (divisor == 1)
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor);
}
Value affineFloorDivConst(
RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
assert(constantAnchor && "expected a valid constant anchor");
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
if (divisor == 1)
return value;
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
}
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
return constant.getValue();
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
unsigned position = dim.getPosition();
if (position >= dims.size())
return failure();
return dims[position];
}
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
unsigned position = symbol.getPosition();
if (position >= symbols.size())
return failure();
return symbols[position];
}
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binary)
return failure();
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
if (failed(lhs) || failed(rhs))
return failure();
switch (binary.getKind()) {
case AffineExprKind::Add: return *lhs + *rhs;
case AffineExprKind::Mul: return *lhs * *rhs;
case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
case AffineExprKind::Mod: {
FailureOr<int64_t> div = floorDivSigned(*lhs, *rhs);
if (failed(div))
return failure();
return *lhs - *div * *rhs;
}
default: return failure();
}
}
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
return failure();
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
return evaluateAffineExpr(map.getResult(0), dims, symbols);
}
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
SmallVector<int64_t, 4> operands;
operands.reserve(affineApply.getMapOperands().size());
for (Value operand : affineApply.getMapOperands()) {
FailureOr<int64_t> folded = resolver(operand);
if (failed(folded))
return failure();
operands.push_back(*folded);
}
return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
}
bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; }
bool isDimAndConstantAffineExpr(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Constant:
case AffineExprKind::DimId: return true;
case AffineExprKind::SymbolId: return false;
case AffineExprKind::Add: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS());
}
case AffineExprKind::Mul: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()))
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()));
}
case AffineExprKind::FloorDiv:
case AffineExprKind::CeilDiv:
case AffineExprKind::Mod: {
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS());
}
}
llvm_unreachable("unexpected affine expression kind");
}
} // namespace onnx_mlir
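
The affine evaluator above relies on floor and ceiling division rather than C++'s truncating `/`, and derives `mod` from floor division so the result is never negative for a positive divisor. A standalone sketch of those three operations, using `std::optional` where the original uses `FailureOr` (function names here are illustrative):

```cpp
#include <cstdint>
#include <optional>

// Rounds toward negative infinity; rejects non-positive divisors.
std::optional<int64_t> floorDiv(int64_t lhs, int64_t rhs) {
  if (rhs <= 0)
    return std::nullopt;
  int64_t quotient = lhs / rhs;  // C++ truncates toward zero
  if (lhs % rhs != 0 && lhs < 0)
    --quotient;
  return quotient;
}

// Rounds toward positive infinity; rejects non-positive divisors.
std::optional<int64_t> ceilDiv(int64_t lhs, int64_t rhs) {
  if (rhs <= 0)
    return std::nullopt;
  int64_t quotient = lhs / rhs;
  if (lhs % rhs != 0 && lhs > 0)
    ++quotient;
  return quotient;
}

// Remainder consistent with floorDiv: lhs - floorDiv(lhs, rhs) * rhs.
std::optional<int64_t> floorMod(int64_t lhs, int64_t rhs) {
  std::optional<int64_t> quotient = floorDiv(lhs, rhs);
  if (!quotient)
    return std::nullopt;
  return lhs - *quotient * rhs;
}
```

The difference from the built-in operators only shows for negative dividends: `-7 / 2` truncates to `-3` with remainder `-1`, whereas the floor variants give `-4` with remainder `1`.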
+55
@@ -0,0 +1,55 @@
#pragma once
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/FunctionExtras.h"
namespace onnx_mlir {
using IndexValueResolver = llvm::function_ref<llvm::FailureOr<int64_t>(mlir::Value)>;
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineMap map,
mlir::ValueRange operands,
mlir::Operation* constantAnchor);
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::AffineExpr expr,
mlir::ValueRange dims,
mlir::Operation* constantAnchor);
mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t multiplier,
mlir::Operation* constantAnchor);
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
mlir::Location loc,
mlir::Value value,
int64_t divisor,
mlir::Operation* constantAnchor);
llvm::FailureOr<int64_t>
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
llvm::FailureOr<int64_t> evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef<int64_t> operands);
llvm::FailureOr<int64_t> evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver);
bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map);
bool isDimAndConstantAffineExpr(mlir::AffineExpr expr);
} // namespace onnx_mlir
+32
@@ -0,0 +1,32 @@
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
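assert(laneCount > 0 && lane < laneCount && "expected a valid lane within a non-empty lane count");
llvm::SmallVector<int32_t> laneCoreIds;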
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
return operandIndex == 3;
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
return operandIndex == 2;
return false;
}
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
}
} // namespace onnx_mlir
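
`getLaneChunkCoreIds` above treats the core-id array as consecutive chunks of `laneCount` entries and picks the entry at position `lane` from each chunk. A standalone sketch of that strided selection with `std::vector` is shown below. The early return for an invalid lane is an addition made for the sketch, and `laneChunkCoreIds` is an illustrative name.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Picks element `lane` out of every consecutive chunk of `laneCount` ids.
std::vector<int32_t> laneChunkCoreIds(const std::vector<int32_t>& coreIds,
                                      std::size_t laneCount,
                                      std::size_t lane) {
  std::vector<int32_t> laneCoreIds;
  // Guard added for this sketch: avoids division by zero and out-of-range lanes.
  if (laneCount == 0 || lane >= laneCount)
    return laneCoreIds;
  const std::size_t chunkCount = coreIds.size() / laneCount;
  laneCoreIds.reserve(chunkCount);
  for (std::size_t chunk = 0; chunk < chunkCount; ++chunk)
    laneCoreIds.push_back(coreIds[chunk * laneCount + lane]);
  return laneCoreIds;
}
```

With ids `{0, 1, 2, 3, 4, 5}` and two lanes, lane 0 receives the even positions and lane 1 the odd ones. Trailing ids that do not fill a whole chunk are ignored, as in the original loop bound.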
+18
@@ -0,0 +1,18 @@
#pragma once
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
} // namespace onnx_mlir
+93
@@ -0,0 +1,93 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dialect.h"
#include "ConstantUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
Block* getConstantInsertionBlock(Operation* anchorOp) {
assert(anchorOp && "expected a valid anchor operation");
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
return &funcOp.getBody().front();
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
return &funcOp.getBody().front();
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
return moduleOp.getBody();
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
return moduleOp.getBody();
return anchorOp->getBlock();
}
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
}
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
assert(anchorOp && "expected a valid anchor operation");
Block* hostBlock = getConstantInsertionBlock(anchorOp);
for (Operation& op : *hostBlock) {
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
continue;
return constantOp.getResult();
}
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(hostBlock);
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
}
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
}
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
Builder builder(anchorOp->getContext());
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
}
std::optional<int64_t> matchConstantIndexValue(Value value) {
if (!value || !value.getType().isIndex())
return std::nullopt;
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
return constant.value();
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
return std::nullopt;
}
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
if (auto attr = dyn_cast<Attribute>(value))
if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
return intAttr.getInt();
if (auto operand = dyn_cast<Value>(value))
return matchConstantIndexValue(operand);
return std::nullopt;
}
} // namespace onnx_mlir
+30
@@ -0,0 +1,30 @@
#pragma once
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/FoldUtils.h"
#include <optional>
namespace onnx_mlir {
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
mlir::Value
getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Value
getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
} // namespace onnx_mlir
+71 -2
@@ -1,25 +1,37 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
bool isCoreStaticAddressOp(mlir::Operation* op) {
return mlir::isa<mlir::arith::ConstantOp,
if (mlir::isa<mlir::arith::ConstantOp,
mlir::arith::AddIOp,
mlir::arith::SubIOp,
mlir::arith::MulIOp,
mlir::arith::DivUIOp,
mlir::arith::DivSIOp,
mlir::arith::MinUIOp,
mlir::arith::RemUIOp,
mlir::arith::RemSIOp,
mlir::arith::IndexCastOp,
mlir::arith::CmpIOp,
mlir::memref::AllocOp,
mlir::memref::SubViewOp,
mlir::memref::CastOp,
mlir::memref::CollapseShapeOp,
mlir::memref::ExpandShapeOp>(op);
mlir::memref::ExpandShapeOp>(op))
return true;
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
return selectOp.getType().isIntOrIndex();
return false;
}
mlir::LogicalResult
@@ -30,6 +42,9 @@ walkPimCoreBlock(mlir::Block& block,
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
@@ -65,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block,
return mlir::success(!hasFailure);
}
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
bool hasFailure = false;
for (mlir::Operation& op : block) {
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
continue;
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
continue;
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
mlir::Block& loopBody = forOp.getRegion().front();
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
auto step = resolveIndexValue(forOp.getStep(), knowledge);
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
hasFailure = true;
continue;
}
if (*step <= 0) {
forOp.emitOpError("requires positive scf.for step for PIM verification");
hasFailure = true;
continue;
}
// Visit only the first and last induction values instead of every iteration.
llvm::SmallVector<int64_t, 2> samples;
if (*lowerBound < *upperBound) {
samples.push_back(*lowerBound);
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
if (last != *lowerBound)
samples.push_back(last);
}
for (int64_t inductionValue : samples) {
StaticValueKnowledge loopKnowledge = knowledge;
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
loopKnowledge.aliases[iterArg] = iterValue;
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
hasFailure = true;
}
continue;
}
if (failed(callback(op, knowledge)))
hasFailure = true;
}
return mlir::success(!hasFailure);
}
} // namespace onnx_mlir
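
`walkPimCoreBlockStructurally` above avoids enumerating every loop iteration by visiting the loop body for at most two induction values: the lower bound and the last value the loop actually reaches. A standalone sketch of that sampling computation (`sampleInductionValues` is an illustrative name; the non-positive-step case, which the original reports as an error, simply yields no samples here):

```cpp
#include <cstdint>
#include <vector>

// Returns the first and last induction values of `for (i = lb; i < ub; i += step)`.
// Empty when the loop never executes; a single value when it runs exactly once.
std::vector<int64_t> sampleInductionValues(int64_t lowerBound, int64_t upperBound, int64_t step) {
  std::vector<int64_t> samples;
  if (step <= 0 || lowerBound >= upperBound)
    return samples;
  samples.push_back(lowerBound);
  // Largest value of the form lowerBound + k * step that is still below upperBound.
  int64_t last = lowerBound + ((upperBound - 1 - lowerBound) / step) * step;
  if (last != lowerBound)
    samples.push_back(last);
  return samples;
}
```

Sampling the two extremes keeps verification cost independent of the trip count, on the assumption that addresses depending on the induction variable reach their extremes at the first or last iteration.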
+8
@@ -21,4 +21,12 @@ walkPimCoreBlock(mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
/// Walks a `pim.core`-like body structurally for verification without
/// enumerating full loop trip counts. Loop bounds must still be statically
/// evaluable so address resolution remains well-defined.
mlir::LogicalResult walkPimCoreBlockStructurally(
mlir::Block& block,
const StaticValueKnowledge& knowledge,
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
} // namespace onnx_mlir
+77
@@ -1,4 +1,5 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
return numElements;
}
bool hasByteSizedElementType(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return true;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
return false;
}
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
if (mlir::isa<mlir::IndexType>(elementType))
return mlir::IndexType::kInternalStorageBitWidth / 8;
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
return static_cast<size_t>(intType.getWidth() / 8);
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
return static_cast<size_t>(floatType.getWidth() / 8);
llvm_unreachable("expected byte-sized integer, float, or index element type");
}
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
}
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
@@ -86,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
return true;
}
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides) {
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|| sourceShape.size() != staticStrides.size()) {
return false;
}
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
return false;
auto reversedTriples =
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
auto [_sourceDim, offset, _size] = triple;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
return true;
});
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
if (size > sourceDim - staticOffset)
return false;
}
++firstNonZeroOrDynamicOffset;
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
if (std::get<2>(*it) != 1)
return false;
}
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
auto [sourceDim, size] = pair;
return size != sourceDim;
});
if (firstDifferentSize != reversedSizes.end()) {
++firstDifferentSize;
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
if (std::get<1>(*it) != 1)
return false;
}
return true;
}
} // namespace onnx_mlir
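
`isContiguousSubviewWithDynamicOffsets` above scans dimensions from innermost to outermost. Once it meets a non-zero or dynamic offset, or a size smaller than the source dimension, every dimension further out must have size 1. The sketch below restates that rule without MLIR types, as a reading of the function rather than a replacement for it. A `std::nullopt` offset stands in for a dynamic `OpFoldResult`, and all names are illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Offset = std::optional<int64_t>;  // nullopt marks a dynamic offset

bool isContiguousWithDynamicOffsets(const std::vector<int64_t>& shape,
                                    const std::vector<Offset>& offsets,
                                    const std::vector<int64_t>& sizes,
                                    const std::vector<int64_t>& strides) {
  const std::size_t rank = shape.size();
  if (offsets.size() != rank || sizes.size() != rank || strides.size() != rank)
    return false;
  for (int64_t stride : strides)
    if (stride != 1)
      return false;

  // Innermost dimension whose offset is non-zero or dynamic.
  std::size_t i = rank;
  while (i > 0 && offsets[i - 1].has_value() && *offsets[i - 1] == 0)
    --i;
  if (i > 0) {
    const std::size_t dim = i - 1;
    // A static offset must leave room for the requested size.
    if (offsets[dim] && sizes[dim] > shape[dim] - *offsets[dim])
      return false;
    for (std::size_t outer = 0; outer < dim; ++outer)
      if (sizes[outer] != 1)
        return false;
  }

  // Innermost dimension that is only partially covered.
  std::size_t j = rank;
  while (j > 0 && sizes[j - 1] == shape[j - 1])
    --j;
  if (j > 0)
    for (std::size_t outer = 0; outer + 1 < j; ++outer)
      if (sizes[outer] != 1)
        return false;
  return true;
}
```

Selecting one full 8x16 slice of a 4x8x16 buffer at a dynamic outer index passes the check. Taking two outer slices while the middle offset is dynamic does not, because the selected elements would no longer form a single run in memory.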
+17
@@ -1,8 +1,14 @@
#pragma once
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
namespace onnx_mlir {
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
@@ -14,9 +20,20 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
bool hasByteSizedElementType(mlir::Type elementType);
size_t getElementTypeSizeInBytes(mlir::Type elementType);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
llvm::ArrayRef<int64_t> offsets,
llvm::ArrayRef<int64_t> sizes,
llvm::ArrayRef<int64_t> strides);
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
llvm::ArrayRef<int64_t> staticSizes,
llvm::ArrayRef<int64_t> staticStrides);
} // namespace onnx_mlir
+22
@@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
}
}
Value stripMemRefAddressingOps(Value value) {
while (true) {
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
value = subviewOp.getSource();
continue;
}
Value strippedValue = stripMemRefViewOps(value);
if (strippedValue == value)
return value;
value = strippedValue;
}
}
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
@@ -81,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
return staticOffsets;
}
bool isMemRefBaseAddressableValue(Value value) {
value = stripMemRefAddressingOps(value);
if (isa<BlockArgument>(value))
return true;
Operation* defOp = value.getDefiningOp();
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
}
} // namespace onnx_mlir
+4
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
mlir::Value stripMemRefViewOps(mlir::Value value);
mlir::Value stripMemRefAddressingOps(mlir::Value value);
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
bool isMemRefBaseAddressableValue(mlir::Value value);
} // namespace onnx_mlir
+188 -24
@@ -1,8 +1,14 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -19,29 +25,57 @@ void markWeightAlways(mlir::Operation* op) {
namespace {
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
CompiledIndexExpr makeConstantExpr(int64_t constant) {
CompiledIndexExprNode expr;
expr.kind = CompiledIndexExprNode::Kind::Constant;
expr.constant = constant;
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
}
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
CompiledIndexExprNode expr;
expr.kind = kind;
expr.operands = {std::move(lhs), std::move(rhs)};
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
}
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
}
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
}
template <typename VMMOpTy, typename ParentOpTy>
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
bool found = false;
parentOp.walk([&](mlir::Operation* op) {
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
found |= mvmOp.getWeightIndex() == weightIndex;
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeightIndex() == weightIndex;
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
found |= vmmOp.getWeight() == *weightArg;
});
return found;
}
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
template <typename VMMOpTy, typename ParentOpTy>
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
auto weights = parentOp.getWeights();
llvm::SmallSet<unsigned, 8> visited;
auto walkWeightIndex = [&](unsigned weightIndex) {
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
auto walkWeight = [&](mlir::Value weight) {
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
auto weightArg = parentOp.getWeightArgument(weightIndex);
if (!weightArg || *weightArg != weight)
continue;
if (visited.insert(weightIndex).second)
callback(parentOp->getOpOperand(weightIndex));
break;
}
};
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
}
} // namespace
@@ -54,7 +88,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
if (!computeOp || operandIndex >= computeOp.getWeights().size())
return false;
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
}
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
@@ -76,8 +110,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
return false;
});
@@ -90,19 +124,149 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
assert(root && "expected valid root op");
root->walk([&](pim::PimCoreOp coreOp) {
coreOp.walk([&](pim::PimVMMOp vmmOp) {
auto weights = coreOp.getWeights();
unsigned weightIndex = vmmOp.getWeightIndex();
if (weightIndex < weights.size())
callback(coreOp->getOpOperand(weightIndex));
if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
callback(coreOp->getOpOperand(*weightIndex));
});
});
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
auto weights = coreBatchOp.getWeights();
for (auto weight : weights)
for (mlir::OpOperand& use : weight.getUses())
if (use.getOwner() == coreBatchOp.getOperation())
callback(use);
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
callback(coreBatchOp->getOpOperand(*weightIndex));
});
});
}
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
weight = stripMemRefAddressingOps(weight);
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
if (coreOp.getWeightArgument(weightIndex) == weight)
return weightIndex;
return std::nullopt;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
return weightIndex;
return std::nullopt;
}
return std::nullopt;
}
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
llvm::SmallVector<mlir::Operation*> viewOps;
mlir::Value current = weight;
while (true) {
if (auto defOp = current.getDefiningOp()) {
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return mlir::failure();
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return mlir::failure();
ResolvedWeightView view;
view.globalOp = globalOp;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
llvm::SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getMixedOffsets().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
CompiledIndexExpr offsetValue = makeConstantExpr(0);
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
if (!intAttr)
return mlir::failure();
offsetValue = makeConstantExpr(intAttr.getInt());
}
else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
auto compiledOffset = compileIndexExpr(value);
if (failed(compiledOffset))
return mlir::failure();
offsetValue = *compiledOffset;
}
else {
return mlir::failure();
}
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return mlir::failure();
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
}
}
auto resolvedOffset = offsetExpr.evaluate(knowledge);
if (failed(resolvedOffset))
return mlir::failure();
view.offset = *resolvedOffset;
return view;
}
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
viewOps.push_back(defOp);
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
current = subview.getSource();
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
current = collapse.getSrc();
else
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
continue;
}
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
current = castOp.getSource();
continue;
}
return mlir::failure();
}
auto weightIndex = resolveWeightIndex(weightOwner, current);
if (!weightIndex)
return mlir::failure();
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
current = coreOp.getWeights()[*weightIndex];
continue;
}
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
current = coreBatchOp.getWeights()[*weightIndex];
continue;
}
return mlir::failure();
}
}
} // namespace onnx_mlir
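Note on `resolveWeightView`: each subview adds `offset_i * sourceStride_i` to the running offset and multiplies its own stride by the source stride, starting from the row-major strides of the global. The sketch below shows that arithmetic with static integers only. It leaves out the dynamic-offset path that the patch handles through `CompiledIndexExpr`, and `View`, `rowMajorStrides`, and `applySubview` are hypothetical names for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical, static-only analogue of ResolvedWeightView.
struct View {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
  int64_t offset = 0;
};

// Row-major strides: the innermost dimension has stride 1.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Compose one subview (offsets, sizes, steps) onto an existing view.
View applySubview(const View& src,
                  const std::vector<int64_t>& offsets,
                  const std::vector<int64_t>& sizes,
                  const std::vector<int64_t>& steps) {
  View out;
  out.offset = src.offset;
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    out.offset += offsets[i] * src.strides[i];          // linearized element offset
    out.strides.push_back(steps[i] * src.strides[i]);   // stride in the base buffer
  }
  out.shape = sizes;
  return out;
}
```

For a 4x8 base, a subview at offsets (1, 2) with sizes (2, 3) and steps (1, 2) starts at element 1*8 + 2*1 = 10 and has strides (8, 2). Those strides no longer match the row-major strides of the 2x3 shape, which is the condition under which the collapse/expand branches above return failure.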
+36 -1
@@ -1,15 +1,34 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <optional>
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
namespace onnx_mlir {
struct ResolvedWeightView {
mlir::memref::GlobalOp globalOp;
llvm::SmallVector<int64_t> shape;
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
bool operator==(const ResolvedWeightView& other) const {
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
}
};
bool hasWeightAlways(mlir::Operation* op);
/// Tags an op as producing a value that should stay materialized as a reusable
@@ -26,4 +45,20 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
/// passes can identify globals that must remain weight-backed.
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
llvm::FailureOr<ResolvedWeightView>
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
template <typename CoreLikeOpTy>
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
llvm::SmallVector<unsigned, 8> indices;
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
indices.push_back(*weightIndex);
});
llvm::sort(indices);
return indices;
}
} // namespace onnx_mlir
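Note on `getUsedWeightIndices`: the template collects each resolved index once and sorts the result, so callers see a deterministic order regardless of walk order. A standalone sketch of the same dedupe-then-sort step follows, using `std::vector` instead of `llvm::SmallVector`; `uniqueSorted` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Keep the first occurrence of every index, then sort ascending.
std::vector<unsigned> uniqueSorted(const std::vector<unsigned>& seen) {
  std::vector<unsigned> out;
  for (unsigned index : seen)
    if (std::find(out.begin(), out.end(), index) == out.end())
      out.push_back(index);
  std::sort(out.begin(), out.end());
  return out;
}
```

The linear `find` is quadratic in the worst case, which is usually acceptable for the handful of weights attached to one core.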
+1
@@ -12,6 +12,7 @@
#include "llvm/ADT/StringRef.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
+1 -1
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
llvm::raw_os_ostream os(file);
mlir::OpPrintingFlags flags;
flags.elideLargeElementsAttrs();
flags.elideLargeElementsAttrs().enableDebugInfo(true, false);
moduleOp.print(os, flags);
os.flush();
file.close();
+3 -3
@@ -13,7 +13,8 @@
namespace onnx_mlir::pim {
struct CappedDiagnosticReporter {
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
: maxReportedFailures(maxReportedFailures) {}
template <typename EmitFn>
void report(mlir::Operation* op, EmitFn&& emit) {
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
if (numFailures > maxReportedFailures)
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
<< failureDescription;
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
}
bool hasFailure() const { return numFailures != 0; }
-1
@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
add_pim_library(OMPimCompilerUtils
PimCompilerUtils.cpp
PimArtifactWriter.cpp
PimBatchEmission.cpp
PimCodeGen.cpp
PimWeightEmitter.cpp
+1 -1
@@ -48,7 +48,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
if (!denseAttr)
return;
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
ArrayRef<char> rawData = denseAttr.getRawData();
char* dst = memoryBuffer.data() + memEntry.address;
-136
@@ -1,136 +0,0 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
}
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
SmallVector<int32_t> laneCoreIds;
laneCoreIds.reserve(coreIds.size() / laneCount);
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
return laneCoreIds;
}
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
IRRewriter rewriter(scalarCore.getContext());
SmallVector<Operation*> batchOps;
scalarCore.walk([&](Operation* op) {
if (isa<pim::PimSendBatchOp,
pim::PimSendTensorBatchOp,
pim::PimReceiveBatchOp,
pim::PimReceiveTensorBatchOp,
pim::PimMemCopyHostToDevBatchOp>(op)) {
batchOps.push_back(op);
}
});
for (Operation* op : batchOps) {
rewriter.setInsertionPoint(op);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
pim::PimSendOp::create(rewriter,
sendBatchOp.getLoc(),
sendBatchOp.getInput(),
sendBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
rewriter.eraseOp(op);
continue;
}
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
pim::PimSendTensorOp::create(
rewriter,
sendTensorBatchOp.getLoc(),
sendTensorBatchOp.getInput(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
rewriter.eraseOp(op);
continue;
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
auto scalarReceive =
pim::PimReceiveOp::create(rewriter,
receiveBatchOp.getLoc(),
receiveBatchOp.getOutput().getType(),
receiveBatchOp.getOutputBuffer(),
receiveBatchOp.getSizeAttr(),
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
auto scalarReceive = pim::PimReceiveTensorOp::create(
rewriter,
receiveTensorBatchOp.getLoc(),
receiveTensorBatchOp.getOutput().getType(),
receiveTensorBatchOp.getOutputBuffer(),
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
rewriter.replaceOp(op, scalarReceive->getResults());
continue;
}
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
memcpBatchOp.getDeviceTarget(),
memcpBatchOp.getHostSource(),
memcpBatchOp.getDeviceTargetOffsetAttr(),
memcpBatchOp.getHostSourceOffsetAttr(),
memcpBatchOp.getSizeAttr());
rewriter.replaceOp(op, scalarCopy->getResults());
}
}
} // namespace
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
OpBuilder builder(scratchModule->getContext());
builder.setInsertionPointToStart(scratchModule->getBody());
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
SmallVector<Value> laneWeights;
laneWeights.reserve(weightsPerLane);
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
auto coreIds = getBatchCoreIds(coreBatchOp);
auto scalarCore = pim::PimCoreOp::create(
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
IRMapping mapper;
if (coreBatchOp.getBody().front().getNumArguments() == 1)
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
builder.setInsertionPointToEnd(block);
for (Operation& op : coreBatchOp.getBody().front()) {
Operation* cloned = builder.clone(op, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
}
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
return callback(scalarCore);
}
} // namespace onnx_mlir
-13
@@ -1,13 +0,0 @@
#pragma once
#include "llvm/ADT/STLFunctionalExtras.h"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
namespace onnx_mlir {
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
} // namespace onnx_mlir
File diff suppressed because it is too large
+60 -13
@@ -4,13 +4,16 @@
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_os_ostream.h"
#include <fstream>
#include <limits>
#include <optional>
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
@@ -23,6 +26,13 @@ struct MemEntry {
size_t size;
};
struct MemoryValueKey {
mlir::Value value;
std::optional<unsigned> lane;
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
};
struct MemoryReportRow {
uint64_t numAlloca = 0;
uint64_t sizeAlloca = 0;
@@ -50,33 +60,33 @@ struct MemoryReportEntry {
};
class PimMemory {
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
llvm::SmallVector<std::pair<MemEntry, MemoryValueKey>, 32> memEntries;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
size_t minAlignment = 4;
size_t firstAvailableAddress = 0;
MemEntry* gatherMemEntry(mlir::Value value);
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
void allocateGatheredMemory();
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry);
public:
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
: globalMemEntriesMap(globalMemEntriesMap) {}
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
void allocateCore(mlir::Operation* op);
void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
MemoryReportRow getReportRow() const;
void remove(mlir::Value val);
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
MemEntry getMemEntry(mlir::Value value) const;
MemEntry getMemEntry(const MemoryValueKey& key) const;
};
class PimAcceleratorMemory {
public:
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
PimMemory hostMem;
private:
@@ -84,14 +94,23 @@ private:
std::fstream fileReport;
std::optional<MemoryReportRow> hostReportRow;
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
public:
PimAcceleratorMemory()
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
: memEntriesMap(initialMemEntries),
hostMem(memEntriesMap),
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
PimMemory& getOrCreateDeviceMem(size_t id);
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
size_t getValueAddress(mlir::Value value,
const StaticValueKnowledge& knowledge = {},
std::optional<unsigned> lane = std::nullopt) const;
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
void reportHost();
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
void recordBatchReport(uint64_t batchId,
@@ -103,15 +122,24 @@ public:
void clean(mlir::Operation* op);
};
struct CoreEmissionJob {
mlir::Operation* coreLikeOp = nullptr;
size_t originalCoreId = 0;
size_t emittedCoreId = 0;
llvm::SmallVector<unsigned, 4> lanes;
std::optional<uint64_t> batchReportId;
};
class PimCodeGen {
PimAcceleratorMemory& memory;
llvm::raw_fd_ostream& coreBinaryStream;
llvm::raw_fd_ostream* coreJsonStream;
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
std::optional<unsigned> batchLane;
mutable uint32_t emittedInstructionCount = 0;
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getValueAddress(value, knowledge);
return memory.getValueAddress(value, knowledge, batchLane);
}
size_t remapCoreId(size_t coreId) const;
@@ -141,15 +169,17 @@ public:
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
return memory.getIndexValue(value, knowledge);
}
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
template <typename MVMTy>
@@ -172,3 +202,20 @@ public:
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
} // namespace onnx_mlir
namespace llvm {
template <>
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
}
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
};
} // namespace llvm
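Note on the `DenseMapInfo<MemoryValueKey>` specialization: the hash folds a missing lane into `std::numeric_limits<unsigned>::max()`, so a key without a lane and a key with the maximum lane hash alike, and `operator==` is what keeps them apart. The sketch below reproduces that arrangement with `std::unordered_map`. `LaneKey`, `LaneKeyHash`, and `LaneMap` are hypothetical names, and the mixing constant is a common hash-combine recipe rather than what `llvm::hash_combine` does.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <optional>
#include <unordered_map>

// Hypothetical stand-in for MemoryValueKey: an opaque handle plus an optional lane.
struct LaneKey {
  const void* value = nullptr;
  std::optional<unsigned> lane;
  bool operator==(const LaneKey& other) const { return value == other.value && lane == other.lane; }
};

struct LaneKeyHash {
  std::size_t operator()(const LaneKey& key) const {
    std::size_t h1 = std::hash<const void*>{}(key.value);
    // A missing lane hashes like the maximum lane; equality still distinguishes them.
    std::size_t h2 = std::hash<unsigned>{}(key.lane.value_or(std::numeric_limits<unsigned>::max()));
    return h1 ^ (h2 + 0x9e3779b9u + (h1 << 6) + (h1 >> 2));
  }
};

using LaneMap = std::unordered_map<LaneKey, int, LaneKeyHash>;
```

With this setup the same value can carry one entry per lane plus a lane-less entry, which is the property the `MemoryValueKey` maps above appear to rely on for per-lane allocation.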
+4 -11
@@ -1,7 +1,7 @@
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "llvm/Support/ErrorHandling.h"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#define DEBUG_TYPE "PimCompilerOptions"
namespace onnx_mlir {
@@ -15,11 +15,10 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
llvm::cl::init(EmitPimCodegen),
llvm::cl::cat(OnnxMlirOptions));
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
"pim-merge-scheduler",
llvm::cl::opt<PimMergeSchedulerType>
pimMergeScheduler("pim-merge-scheduler",
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
llvm::cl::init(MergeSchedulerPeft),
llvm::cl::cat(OnnxMlirOptions));
@@ -49,12 +48,6 @@ llvm::cl::opt<long> coresCount("core-count",
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
llvm::cl::init(-1));
llvm::cl::opt<size_t> dcpCriticalWindowSize(
"dcp-critical-window-size",
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
llvm::cl::init(4000));
llvm::cl::opt<bool>
ignoreConcatError("ignore-concat-error",
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
-2
@@ -22,7 +22,6 @@ typedef enum {
typedef enum {
MergeSchedulerPeft = 0,
MergeSchedulerDcp = 1,
} PimMergeSchedulerType;
extern llvm::cl::OptionCategory OnnxMlirOptions;
@@ -36,7 +35,6 @@ extern llvm::cl::opt<bool> pimEmitJson;
extern llvm::cl::opt<size_t> crossbarSize;
extern llvm::cl::opt<size_t> crossbarCountInCore;
extern llvm::cl::opt<long> coresCount;
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
bool hasExplicitPimCoreCount();
void verifyExplicitPimCoreCount();
-4
@@ -30,20 +30,17 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
if (pimEmissionTarget >= EmitSpatial) {
pm.addPass(createONNXToSpatialPass());
pm.addPass(createMergeComputeNodesPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
}
if (pimEmissionTarget >= EmitPim) {
pm.addPass(createSpatialToPimPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Spatial lowered to Pim"));
}
if (pimEmissionTarget >= EmitPimBufferized) {
pm.addPass(createPimBufferizationPass());
pm.addPass(createPimStaticMemoryCoalescingPass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim bufferized"));
}
@@ -54,7 +51,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
pm.addPass(createPimVerificationPass());
pm.addPass(createMessagePass("Pim verified"));
pm.addPass(createEmitPimCodePass());
// pm.addPass(createCountInstructionPass());
pm.addPass(createMessagePass("Pim code emitted"));
}
}
+22 -184
@@ -1,9 +1,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/raw_ostream.h"
@@ -11,196 +9,42 @@
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace llvm;
using namespace mlir;
namespace onnx_mlir {
namespace {
namespace {} // namespace
struct DenseWeightView {
DenseElementsAttr denseAttr;
SmallVector<int64_t> shape;
SmallVector<int64_t> strides;
int64_t offset = 0;
};
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
SmallVector<Operation*> viewOps;
mlir::Value current = weight;
memref::GetGlobalOp getGlobalOp;
while (true) {
Operation* defOp = current.getDefiningOp();
if (!defOp)
return failure();
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
break;
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
if (!hasAllStaticSubviewParts(subview))
return failure();
viewOps.push_back(subview);
current = subview.getSource();
continue;
}
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
current = cast.getSource();
continue;
}
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(collapse);
current = collapse.getSrc();
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
viewOps.push_back(expand);
current = expand.getSrc();
continue;
}
return failure();
}
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
if (!globalOp || !globalOp.getInitialValue())
return failure();
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
if (!denseAttr)
return failure();
DenseWeightView view;
view.denseAttr = denseAttr;
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
view.strides = computeRowMajorStrides(view.shape);
for (Operation* viewOp : llvm::reverse(viewOps)) {
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
SmallVector<int64_t> nextStrides;
nextStrides.reserve(subview.getStaticStrides().size());
for (auto [offset, stride, sourceStride] :
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
view.offset += offset * sourceStride;
nextStrides.push_back(stride * sourceStride);
}
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
view.strides = std::move(nextStrides);
continue;
}
// Collapse/expand are accepted only as contiguous static reshapes of a
// dense global view, so a row-major stride recomputation preserves layout.
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(collapse.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
if (view.strides != computeRowMajorStrides(view.shape))
return failure();
auto resultType = cast<MemRefType>(expand.getResult().getType());
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
view.strides = computeRowMajorStrides(view.shape);
continue;
}
}
return view;
}
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
SmallVector<unsigned, 8> indices;
auto addIndex = [&](unsigned weightIndex) {
if (!llvm::is_contained(indices, weightIndex))
indices.push_back(weightIndex);
};
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
llvm::sort(indices);
return indices;
}
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
return getUsedWeightIndices(coreOp.getBody().front());
}
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
SmallVector<Operation*> coreLikeOps;
for (Operation& op : funcOp.getBody().front())
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
coreLikeOps.push_back(&op);
return coreLikeOps;
}
} // namespace
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
auto coreWeightsDirPath = outputDirPath + "/weights";
auto error = sys::fs::create_directory(coreWeightsDirPath);
assert(!error && "Error creating weights directory");
size_t indexFileName = 0;
int64_t xbarSize = crossbarSize.getValue();
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
it != materializedWeights.end())
return it->second;
for (Operation* op : coreLikeOps) {
auto processCore = [&](pim::PimCoreOp coreOp) {
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
for (unsigned index : getUsedWeightIndices(coreOp)) {
if (index >= coreOp.getWeights().size()) {
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
}
mlir::Value weight = coreOp.getWeights()[index];
auto globalOp = weightView.globalOp;
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
assert(denseAttr && "Weight global must have dense initial value");
auto weightView = resolveDenseWeightView(moduleOp, weight);
if (failed(weightView)) {
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
}
if (mapCoreWeightToFileName[coreId].contains(weight))
continue;
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
auto& fileName = mapGlobalOpToFileName[globalOp];
mapCoreWeightToFileName[coreId].insert({weight, fileName});
continue;
}
DenseElementsAttr denseAttr = weightView->denseAttr;
ArrayRef<int64_t> shape = weightView->shape;
ArrayRef<int64_t> shape = weightView.shape;
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
int64_t numRows = shape[0];
int64_t numCols = shape[1];
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
@@ -215,7 +59,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
for (int64_t row = 0; row < xbarSize; row++) {
for (int64_t col = 0; col < xbarSize; col++) {
if (row < numRows && col < numCols) {
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
uint64_t word = bits.getZExtValue();
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
@@ -227,23 +71,17 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
}
weightFileStream.close();
if (globalOp)
mapGlobalOpToFileName.insert({globalOp, newFileName});
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
}
return success();
materializedWeights.push_back({weightView, newFileName});
return newFileName;
};
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
(void) processCore(coreOp);
continue;
for (const WeightFileRequest& request : requests) {
auto& coreFiles = mapCoreWeightToFileName[request.coreId];
coreFiles.reserve(request.weights.size());
for (const ResolvedWeightView& weight : request.weights)
coreFiles.push_back(materializeWeight(weight));
}
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
return mapCoreWeightToFileName;
}
return mapCoreWeightToFileName;
}
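Review note: the subview folding in `resolveDenseWeightView` and the strided index `offset + row * strides[0] + col * strides[1]` used when writing each crossbar file can be checked in isolation. The sketch below is a standalone approximation that uses `std::vector` instead of the MLIR and LLVM types. `StridedView`, `applySubview` and `linearIndex` are hypothetical names introduced only for illustration; they are not part of the repository.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major strides for a static shape: the innermost dimension has stride 1.
std::vector<int64_t> computeRowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

// Hypothetical stand-in for the view state: element offset, shape and strides.
struct StridedView {
  int64_t offset = 0;
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

// Folds one static subview (offsets, sizes, strides per dimension) into a view.
// Each subview offset advances the base offset by offset * sourceStride, and
// each subview stride is scaled by the stride of the dimension it steps over.
StridedView applySubview(const StridedView& view,
                         const std::vector<int64_t>& offsets,
                         const std::vector<int64_t>& sizes,
                         const std::vector<int64_t>& strides) {
  assert(offsets.size() == view.strides.size() && strides.size() == view.strides.size());
  StridedView result;
  result.offset = view.offset;
  result.shape = sizes;
  result.strides.reserve(strides.size());
  for (std::size_t dim = 0; dim < strides.size(); ++dim) {
    result.offset += offsets[dim] * view.strides[dim];
    result.strides.push_back(strides[dim] * view.strides[dim]);
  }
  return result;
}

// Linear element index of (row, col) in a 2-D view, as used by the file writer.
int64_t linearIndex(const StridedView& view, int64_t row, int64_t col) {
  return view.offset + row * view.strides[0] + col * view.strides[1];
}
```

For a 4x6 dense constant, a subview with offsets `{1, 2}`, sizes `{2, 3}` and strides `{2, 1}` starts at element 8 and has strides `{12, 1}`, so its element `(1, 2)` maps to row 3, column 4 of the original buffer.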
+10 -3
@@ -1,16 +1,23 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include <string>
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
namespace onnx_mlir {
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
struct WeightFileRequest {
size_t coreId = 0;
llvm::SmallVector<ResolvedWeightView, 8> weights;
};
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
} // namespace onnx_mlir
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(ONNXToSpatialIncGen)
add_pim_library(OMONNXToSpatial
ConversionPatterns.cpp
HostFoldability.cpp
HostLegality.cpp
PrePatterns.cpp
PostPatterns.cpp
Patterns.cpp
CompileTime.cpp
ONNXToSpatialVerifier.cpp
Patterns/Pre.cpp
Patterns/Post.cpp
Patterns/GeneratedConversion.cpp
Patterns/Math/Conv.cpp
Patterns/Math/Elementwise.cpp
Patterns/Math/Gemm.cpp
@@ -22,8 +23,11 @@ add_pim_library(OMONNXToSpatial
Patterns/Tensor/Resize.cpp
Patterns/Tensor/Reshape.cpp
Patterns/Tensor/Split.cpp
Patterns/Tensor/Transpose.cpp
ONNXToSpatialPass.cpp
Common/AttributeUtils.cpp
Common/ComputeRegionBuilder.cpp
Common/IndexingUtils.cpp
Common/ShapeTilingUtils.cpp
Common/WeightMaterialization.cpp
@@ -33,6 +37,7 @@ add_pim_library(OMONNXToSpatial
ONNXToSpatialIncGen
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
MLIRTosaDialect
OMCompilerOptions
@@ -0,0 +1,23 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "AttributeUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
return attr ? getI64Attr(*attr, index) : defaultValue;
}
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
llvm::SmallVector<int64_t> values;
values.reserve(attr.size());
for (Attribute value : attr)
values.push_back(cast<IntegerAttr>(value).getInt());
return values;
}
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "llvm/ADT/SmallVector.h"
#include <cstddef>
#include <optional>
namespace onnx_mlir {
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
} // namespace onnx_mlir
@@ -1,6 +1,8 @@
#pragma once
#include "AttributeUtils.hpp"
#include "ComputeRegionBuilder.hpp"
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
@@ -1,5 +1,6 @@
#pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ValueRange.h"
@@ -7,9 +8,11 @@
#include <cassert>
#include <cstddef>
#include <limits>
#include <type_traits>
#include <utility>
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
@@ -18,13 +21,17 @@ namespace detail {
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(block->getArgument(Is)...);
}
template <typename Fn, size_t... Is>
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
return std::forward<Fn>(fn)(values[Is]...);
}
@@ -45,6 +52,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
template <typename Fn>
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
struct SpatComputeBatchBodyArgs {
mlir::Value lane;
mlir::ValueRange weights;
mlir::ValueRange inputs;
mlir::ValueRange outputs;
};
} // namespace detail
template <typename RewriterT>
@@ -85,6 +99,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
@@ -93,14 +109,17 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
if constexpr (std::is_same_v<BodyResult, void>) {
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult =
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
detail::getInputBlockArgs(block, weights.size()),
std::make_index_sequence<NumInputs> {});
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -123,6 +142,8 @@ auto createSpatCompute(RewriterT& rewriter,
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
auto* block = new mlir::Block();
for (mlir::Value weight : weights)
block->addArgument(weight.getType(), loc);
for (mlir::Value input : inputs)
block->addArgument(input.getType(), loc);
@@ -131,13 +152,13 @@ auto createSpatCompute(RewriterT& rewriter,
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
rewriter.setInsertionPointAfter(computeOp);
return computeOp;
}
else {
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(computeOp);
rewriter.eraseOp(computeOp);
@@ -148,6 +169,95 @@ auto createSpatCompute(RewriterT& rewriter,
}
}
template <typename RewriterT, typename BodyFn>
auto createSpatComputeBatch(RewriterT& rewriter,
mlir::Location loc,
mlir::TypeRange resultTypes,
int64_t laneCount,
mlir::ValueRange weights,
mlir::ValueRange inputs,
BodyFn&& body) {
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
auto batchOp = spatial::SpatComputeBatch::create(
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
for (mlir::Value weight : weights) {
blockArgTypes.push_back(weight.getType());
blockArgLocs.push_back(weight.getLoc());
}
for (mlir::Value input : inputs) {
blockArgTypes.push_back(input.getType());
blockArgLocs.push_back(input.getLoc());
}
for (mlir::Type resultType : resultTypes) {
blockArgTypes.push_back(resultType);
blockArgLocs.push_back(loc);
}
auto* block =
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
rewriter.setInsertionPointToStart(block);
detail::SpatComputeBatchBodyArgs args {
block->getArgument(0),
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
if constexpr (std::is_same_v<BodyResult, void>) {
std::forward<BodyFn>(body)(args);
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
else {
auto bodyResult = std::forward<BodyFn>(body)(args);
if (mlir::failed(bodyResult)) {
rewriter.setInsertionPointAfter(batchOp);
rewriter.eraseOp(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
}
rewriter.setInsertionPointAfter(batchOp);
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
}
}
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
mlir::ArrayRef<mlir::OpFoldResult> offsets,
mlir::ArrayRef<mlir::OpFoldResult> sizes,
mlir::ArrayRef<mlir::OpFoldResult> strides) {
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
}
template <typename BodyFn>
mlir::Value materializeOrComputeUnary(mlir::Value input,
mlir::RankedTensorType resultType,
mlir::PatternRewriter& rewriter,
mlir::Location loc,
BodyFn&& build) {
auto&& buildFn = build;
if (isCompileTimeComputable(input))
return buildFn(input);
auto computeOp = createSpatCompute<1>(
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
mlir::Value result = buildFn(computeInput);
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
} // namespace onnx_mlir
@@ -0,0 +1,45 @@
#include <algorithm>
#include "IndexingUtils.hpp"
using namespace mlir;
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
int64_t normalizedAxis = normalizeAxis(axis, rank);
if (normalizedAxis < 0 || normalizedAxis >= rank)
return failure();
return normalizedAxis;
}
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; ++axis)
normalizedAxes.push_back(axis);
}
else {
normalizedAxes.reserve(axesAttr->size());
for (Attribute attr : *axesAttr)
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
}
return normalizedAxes;
}
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
for (int64_t axis : normalizedAxes)
if (axis < 0 || axis >= rank)
return failure();
return normalizedAxes;
}
} // namespace onnx_mlir
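Review note: the axis helpers above follow the usual ONNX convention that a negative axis counts from the end. The following is a minimal standalone sketch of the same behaviour, assuming `std::optional` as a stand-in for `FailureOr` and a plain `std::vector<int64_t>` in place of the `ArrayAttr` argument. It checks each axis before sorting and deduplicating, which is a reordering of the original steps that should produce the same result.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Maps a possibly negative axis into [0, rank) without validating it.
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }

// Same mapping, but reports out-of-range axes as std::nullopt
// (stand-in for mlir::failure() in the original code).
std::optional<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
  int64_t normalized = normalizeAxis(axis, rank);
  if (normalized < 0 || normalized >= rank)
    return std::nullopt;
  return normalized;
}

// Normalizes a list of axes, then sorts it and removes duplicates.
std::optional<std::vector<int64_t>> normalizeAxesChecked(const std::vector<int64_t>& axes, int64_t rank) {
  std::vector<int64_t> normalized;
  normalized.reserve(axes.size());
  for (int64_t axis : axes) {
    std::optional<int64_t> checked = normalizeAxisChecked(axis, rank);
    if (!checked)
      return std::nullopt;
    normalized.push_back(*checked);
  }
  std::sort(normalized.begin(), normalized.end());
  normalized.erase(std::unique(normalized.begin(), normalized.end()), normalized.end());
  return normalized;
}
```

With rank 4, the axes `{-1, 3, 0}` collapse to `{0, 3}` because `-1` and `3` name the same dimension, while an axis of `4` or `-5` is rejected.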
@@ -0,0 +1,20 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
namespace onnx_mlir {
int64_t normalizeAxis(int64_t axis, int64_t rank);
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
int64_t normalizeIndex(int64_t index, int64_t dimSize);
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
} // namespace onnx_mlir
@@ -3,26 +3,103 @@
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include "IndexingUtils.hpp"
#include "ShapeTilingUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
bool hasStaticPositiveShape(RankedTensorType type) {
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
}
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
SmallVector<int64_t> permutedShape;
permutedShape.reserve(permutation.size());
for (int64_t axis : permutation)
permutedShape.push_back(shape[axis]);
return permutedShape;
}
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
SmallVector<int64_t> inversePermutation(permutation.size());
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
return inversePermutation;
}
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
SmallVector<int64_t> permutation;
if (!permAttr) {
permutation.reserve(rank);
for (int64_t dim = rank - 1; dim >= 0; --dim)
permutation.push_back(dim);
return permutation;
}
if (static_cast<int64_t>(permAttr->size()) != rank)
return failure();
permutation.reserve(permAttr->size());
SmallVector<bool> seen(rank, false);
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
int64_t axis = attr.getInt();
if (axis < 0 || axis >= rank || seen[axis])
return failure();
seen[axis] = true;
permutation.push_back(axis);
}
return permutation;
}
Value transposeMaybeInCompute(
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
};
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
}
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
}
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
}
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (int64_t dim : shape)
sizes.push_back(rewriter.getIndexAttr(dim));
return sizes;
}
SmallVector<Value> sliceTensor(
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
ArrayRef<long> shape = getTensorShape(tensorToSlice);
assert("Invalid axis" && axis < shape.size());
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
sizes.reserve(shape.size());
for (const auto size : shape)
sizes.push_back(rewriter.getIndexAttr(size));
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
sizes[axis] = rewriter.getIndexAttr(sliceSize);
long length = shape[axis];
@@ -44,7 +121,7 @@ SmallVector<Value> sliceTensor(
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
Value slice;
if (isHostFoldableValue(tensorToSlice)) {
if (isCompileTimeComputable(tensorToSlice)) {
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
}
else {
@@ -80,47 +157,33 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
return slicesPerCore;
}
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
Value extractAxisSlice(
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
auto sourceType = cast<RankedTensorType>(source.getType());
SmallVector<int64_t> resultShape(sourceType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
size_t numHSlices = hSlices.size();
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
Value hSlice = hSlices[hSliceId];
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
size_t coreId = vSliceId / crossbarCountInCore;
Value vSlice = vSlices[vSliceId];
tiles[hSliceId][coreId].push_back(vSlice);
}
}
return tiles;
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
return tensor::ExtractSliceOp::create(
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
Type elementType = oldType.getElementType();
int64_t shape[2] = {1, length};
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
auto buildBroadcast = [&](Value input) -> Value {
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
SmallVector<Value> index(oldType.getRank(), zero);
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
};
if (isHostFoldableValue(scalarToBroadcast))
return buildBroadcast(scalarToBroadcast);
auto broadcastCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
});
return broadcastCompute.getResult(0);
Value insertStaticSlice(
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
auto sourceType = cast<RankedTensorType>(source.getType());
return tensor::InsertSliceOp::create(rewriter,
loc,
source,
dest,
offsets,
getStaticSizes(rewriter, sourceType.getShape()),
getUnitStrides(rewriter, sourceType.getRank()))
.getResult();
}
} // namespace onnx_mlir
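Review note: `permuteShape` and `invertPermutation` are meant to be inverses of each other when composed, which is easy to get backwards. The sketch below mirrors the two helpers with `std::vector` instead of `ArrayRef`/`SmallVector` (an assumption made so that it compiles without MLIR) and can be used to check the round trip.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Result dimension i takes its size from source dimension permutation[i].
std::vector<int64_t> permuteShape(const std::vector<int64_t>& shape, const std::vector<int64_t>& permutation) {
  std::vector<int64_t> permuted;
  permuted.reserve(permutation.size());
  for (int64_t axis : permutation)
    permuted.push_back(shape[static_cast<std::size_t>(axis)]);
  return permuted;
}

// If the permutation sends old axis `oldIndex` to position `newIndex`,
// the inverse sends `newIndex` back to `oldIndex`.
std::vector<int64_t> invertPermutation(const std::vector<int64_t>& permutation) {
  std::vector<int64_t> inverse(permutation.size());
  for (std::size_t newIndex = 0; newIndex < permutation.size(); ++newIndex)
    inverse[static_cast<std::size_t>(permutation[newIndex])] = static_cast<int64_t>(newIndex);
  return inverse;
}
```

For shape `{2, 3, 4}` and permutation `{2, 0, 1}`, the permuted shape is `{4, 2, 3}` and the inverse permutation is `{1, 2, 0}`; applying the inverse to the permuted shape gives back `{2, 3, 4}`. Both functions assume the permutation has already been validated, as `getTransposePermutationChecked` does in the change above.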
@@ -3,6 +3,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/ArrayRef.h"
@@ -11,46 +12,12 @@
#include <cassert>
#include <cstddef>
#include <optional>
#include <type_traits>
#include <utility>
namespace onnx_mlir {
template <class ShapedType>
inline auto getImageWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getImageHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getImageChannel(const ShapedType& shapedType) {
return shapedType.getDimSize(1);
}
template <class ShapedType>
inline auto getImageN(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
template <class ShapedType>
inline auto getKernelWidth(const ShapedType& shapedType) {
return shapedType.getDimSize(2);
}
template <class ShapedType>
inline auto getKernelHeight(const ShapedType& shapedType) {
return shapedType.getDimSize(3);
}
template <class ShapedType>
inline auto getFilterCount(const ShapedType& shapedType) {
return shapedType.getDimSize(0);
}
using HSliceId = size_t;
using CoreId = size_t;
@@ -87,17 +54,6 @@ bool isHVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[0] == 1;
}
template <class T>
bool isVVectorShape(mlir::ArrayRef<T> shape) {
return shape.size() == 2 && shape[1] == 1;
}
template <class T>
T getVectorLength(mlir::ArrayRef<T> shape) {
assert(isVectorShape(shape));
return shape[0] != 1 ? shape[0] : shape[1];
}
inline auto getTensorShape(mlir::Value tensor) {
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
}
@@ -109,6 +65,31 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
&& lhsType.getShape() == rhsType.getShape();
}
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
bool hasStaticPositiveShape(mlir::RankedTensorType type);
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
int64_t rank);
mlir::Value transposeMaybeInCompute(mlir::Value value,
mlir::RankedTensorType resultType,
mlir::ArrayRef<int64_t> permutation,
mlir::PatternRewriter& rewriter,
mlir::Location loc);
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
/// Slices a statically shaped tensor along one axis into contiguous pieces of
/// at most `sliceSize` elements.
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
@@ -127,18 +108,13 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
/// Tiles a matrix first across output columns and then across input rows so it
/// can be assigned to crossbars grouped by core.
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
tileMatrix(mlir::Value& matrixToTile,
int64_t hSliceSize,
int64_t vSliceSize,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location& loc);
mlir::Value extractAxisSlice(
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
int64_t length,
mlir::ConversionPatternRewriter& rewriter,
mlir::Location loc);
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
mlir::Location loc,
mlir::Value source,
mlir::Value dest,
llvm::ArrayRef<mlir::OpFoldResult> offsets);
} // namespace onnx_mlir
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
value = collapseShapeOp.getSrc();
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
value = transposeOp.getData();
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
value = transposeOp.getInput();
continue;
}
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
return referencedValue.getResult();
}
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
return failure();
IRMapping localMapper;
@@ -1,4 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -8,7 +9,11 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include <utility>
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -23,8 +28,7 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
}
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
return llvm::all_of(extractOp.getIndices(),
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
}
static bool isStaticTensorResult(Operation* op) {
@@ -34,13 +38,6 @@ static bool isStaticTensorResult(Operation* op) {
});
}
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
SmallVector<int64_t> strides(shape.size(), 1);
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
strides[dim] = strides[dim + 1] * shape[dim + 1];
return strides;
}
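Note on the helper above: the row-major stride rule it implements can be sketched without LLVM types. This is a minimal standalone illustration, assuming `std::vector` in place of `SmallVector`/`ArrayRef`; the names `rowMajorStrides` and `linearOffset` are hypothetical and do not appear in the commit.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone counterpart of computeRowMajorStrides:
// the innermost dimension has stride 1, and every outer stride is the
// product of the sizes of all dimensions to its right.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

// Linear element offset of a multi-dimensional index under the given strides.
int64_t linearOffset(const std::vector<int64_t>& index, const std::vector<int64_t>& strides) {
  assert(index.size() == strides.size());
  int64_t offset = 0;
  for (std::size_t i = 0; i < index.size(); ++i)
    offset += index[i] * strides[i];
  return offset;
}
```

For a shape of `{2, 3, 4}` the sketch yields strides `{12, 4, 1}`, so index `{1, 2, 3}` maps to element 23 of the flat buffer.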
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!tensorType)
@@ -145,7 +142,7 @@ static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
return nullptr;
}
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
auto* definingOp = value.getDefiningOp();
if (!definingOp || !visited.insert(definingOp).second)
return nullptr;
@@ -156,7 +153,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return denseAttr;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
if (!inputAttr)
return nullptr;
@@ -168,8 +165,18 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
if (!inputAttr)
return nullptr;
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
auto transposedAttr = transposeDenseElements(inputAttr, perm);
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
}
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
@@ -177,7 +184,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
if (!inputAttr)
return nullptr;
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
@@ -185,7 +192,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
}
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
if (!inputAttr)
return nullptr;
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
@@ -195,62 +202,98 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
return nullptr;
}
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
if (!op || !visited.insert(op).second)
return false;
static std::optional<CompileTimeSource>
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
if (!op)
return std::nullopt;
if (!visited.insert(op).second)
return {
{op, chainLength}
};
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
return true;
return {
{op, chainLength}
};
chainLength += 1;
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
return hasConstantIndices(extractOp)
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (!isStaticTensorResult(op))
return false;
return std::nullopt;
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
return isHostFoldableValue(transposeOp.getData());
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
return isHostFoldableValue(collapseShapeOp.getSrc());
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
return isHostFoldableValue(expandShapeOp.getSrc());
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
return hasStaticUnitStrides(extractSliceOp)
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
: std::nullopt;
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
return isHostFoldableValue(splatOp.getInput());
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
return isHostFoldableValue(extractRowsOp.getInput());
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
std::optional<CompileTimeSource> res = {};
for (auto operandValue : concatOp.getOperands()) {
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
if (!partialRes)
return std::nullopt;
return false;
if (!res) {
res = partialRes;
continue;
}
if (res->chainLength < partialRes->chainLength)
res = partialRes;
}
return res;
}
return std::nullopt;
}
} // namespace
bool isHostFoldableValue(Value value) {
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getCompileTimeSourceImpl(op, visited);
}
bool isCompileTimeComputable(Value value) {
auto* definingOp = value.getDefiningOp();
if (!definingOp)
return false;
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(definingOp, visited);
return getCompileTimeSourceImpl(definingOp, visited).has_value();
}
bool isHostFoldableOp(Operation* op) {
bool isCompileTimeOp(Operation* op) {
llvm::SmallPtrSet<Operation*, 8> visited;
return isHostFoldableOpImpl(op, visited);
return getCompileTimeSourceImpl(op, visited).has_value();
}
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
llvm::SmallPtrSet<Operation*, 8> visited;
return getHostFoldableDenseElementsAttrImpl(value, visited);
return getHostConstantDenseElementsAttrImpl(value, visited);
}
} // namespace onnx_mlir
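Note on `getCompileTimeSourceImpl`: the traversal can be modelled on a toy graph to show how `chainLength` accumulates and how the concat case keeps the longest operand chain. This is only a sketch under stated assumptions. `Kind`, `Node`, `Source`, `findSourceImpl`, and `findSource` are hypothetical stand-ins, not MLIR or project types, and every view-like op (transpose, reshape, slice, and so on) is collapsed into a single `View` kind without the static-shape and index checks of the real code.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <set>
#include <vector>

// Toy stand-ins for MLIR operations (hypothetical, for illustration only).
enum class Kind { Constant, View, Concat, Runtime };

struct Node {
  Kind kind;
  std::vector<const Node*> operands;
};

struct Source {
  const Node* source;
  std::size_t chainLength;
};

std::optional<Source> findSourceImpl(const Node* node, std::set<const Node*>& visited, std::size_t chainLength) {
  if (!node)
    return std::nullopt;
  // A node reached a second time is reported as the source, as in the diff.
  if (!visited.insert(node).second)
    return Source{node, chainLength};
  if (node->kind == Kind::Constant)
    return Source{node, chainLength};

  chainLength += 1;

  if (node->kind == Kind::View) {
    const Node* input = node->operands.empty() ? nullptr : node->operands.front();
    return findSourceImpl(input, visited, chainLength);
  }

  if (node->kind == Kind::Concat) {
    // Every operand must be traceable; keep the result with the longest chain.
    std::optional<Source> longest;
    for (const Node* operand : node->operands) {
      std::optional<Source> partial = findSourceImpl(operand, visited, chainLength);
      if (!partial)
        return std::nullopt;
      if (!longest || longest->chainLength < partial->chainLength)
        longest = partial;
    }
    return longest;
  }

  // Anything else is treated as a runtime op.
  return std::nullopt;
}

std::optional<Source> findSource(const Node* node) {
  std::set<const Node*> visited;
  return findSourceImpl(node, visited, 0);
}
```

In this model a single runtime operand makes the whole concat untraceable, which mirrors the early `return std::nullopt` in the loop above.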
@@ -0,0 +1,22 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include <optional>
namespace onnx_mlir {
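// Result of tracing an op back through compile-time-evaluable producers.
struct CompileTimeSource {
// Op at which the traversal stopped: a constant-like op, or an op that was already visited.
mlir::Operation* source;
// Number of ops traversed between the queried op and `source`; for SpatConcatOp, the longest operand chain.
size_t chainLength;
};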
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
bool isCompileTimeComputable(mlir::Value value);
bool isCompileTimeOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -1,15 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
namespace onnx_mlir {
bool isHostFoldableValue(mlir::Value value);
bool isHostFoldableOp(mlir::Operation* op);
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
} // namespace onnx_mlir
@@ -1,34 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getFunctionBody().front()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isHostFoldableOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
"spat.compute");
});
}
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,5 +1,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
@@ -12,13 +14,10 @@
#include "Common/Common.hpp"
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
@@ -44,7 +43,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
IRMapping mapper;
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
if (!computes.empty())
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
if (!computes.empty() || !computeBatches.empty())
return;
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
@@ -85,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
returnOp.setOperand(index, computeResult);
}
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
IRRewriter rewriter(funcOp.getContext());
Block& entryBlock = funcOp.getFunctionBody().front();
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
if (!transposeOp || isHostFoldableOp(transposeOp))
continue;
// Transpose stays globally legal because constant/view-only cases are
// allowed on the host. Any residual runtime transpose must be sunk into
// spat.compute before the host legality check.
auto resultType = transposeOp.getResult().getType();
rewriter.setInsertionPoint(transposeOp);
auto computeOp = createSpatCompute<1>(
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
});
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
}
}
void ONNXToSpatialPass::runOnOperation() {
ModuleOp moduleOp = getOperation();
MLIRContext* ctx = &getContext();
@@ -116,7 +92,9 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget preTarget(*ctx);
preTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
@@ -154,10 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget target(*ctx);
target.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
target.addIllegalOp<ONNXMatMulOp>();
target.addIllegalOp<ONNXTransposeOp>();
target.addIllegalOp<ONNXAddOp>();
target.addIllegalOp<ONNXDivOp>();
target.addIllegalOp<ONNXMulOp>();
@@ -184,22 +165,18 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
RewritePatternSet transposePatterns(ctx);
populateTransposePatterns(transposePatterns, ctx);
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
ConversionTarget earlyPostTarget(*ctx);
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
RewritePatternSet earlyPostPatterns(ctx);
populateEarlyPostPatterns(earlyPostPatterns, ctx);
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
signalPassFailure();
return;
}
PassManager cleanupPM(ctx);
cleanupPM.addPass(createCanonicalizerPass());
@@ -211,7 +188,9 @@ void ONNXToSpatialPass::runOnOperation() {
ConversionTarget postTarget(*ctx);
postTarget.addLegalDialect<spatial::SpatialDialect,
ONNXDialect,
linalg::LinalgDialect,
tensor::TensorDialect,
affine::AffineDialect,
arith::ArithDialect,
scf::SCFDialect>();
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
@@ -227,9 +206,7 @@ void ONNXToSpatialPass::runOnOperation() {
return;
}
wrapTopLevelRuntimeTransposes(*entryFunc);
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
if (failed(verifyONNXToSpatial(*entryFunc))) {
moduleOp.emitError("ONNX-to-Spatial verification failed");
signalPassFailure();
return;
@@ -0,0 +1,157 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LLVM.h"
#include "Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
func.walk([&](Operation* op) {
if (!hasWeightAlways(op))
return;
for (Value result : op->getResults()) {
if (hasOnlySpatialMvmVmmWeightUses(result))
continue;
diagnostics.report(op, [&](Operation* illegalOp) {
illegalOp->emitOpError(
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
});
return;
}
});
}
Region* getParentRegion(Value value) {
if (auto blockArg = dyn_cast<BlockArgument>(value))
return blockArg.getOwner()->getParent();
if (Operation* definingOp = value.getDefiningOp())
return definingOp->getParentRegion();
return nullptr;
}
bool isDefinedInsideRegion(Value value, Region& region) {
Region* parentRegion = getParentRegion(value);
return parentRegion && (&region == parentRegion || region.isAncestor(parentRegion));
}
bool isLegalHostBackedValue(Value value) {
Operation* definingOp = value.getDefiningOp();
if (!definingOp)
return isa<BlockArgument>(value);
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
return false;
return definingOp->getDialect()->getNamespace() != "spat";
}
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
ValueRange inputs,
bool allowChannelReceiveInputs,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
unsigned currentInputIndex = inputIndex;
Operation* definingOp = input.getDefiningOp();
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
continue;
if (isLegalHostBackedValue(input))
continue;
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
<< kind << " input #" << currentInputIndex
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
"spat.channel_receive"
: " must come from the host");
if (definingOp)
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
});
return failure();
}
return success();
}
void verifyNoExternalTensorCaptures(Operation* ownerOp,
Region& region,
StringRef kind,
pim::CappedDiagnosticReporter& diagnostics) {
region.walk([&](Operation* op) {
for (OpOperand& operand : op->getOpOperands()) {
Value value = operand.get();
if (!isa<TensorType>(value.getType()))
continue;
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
continue;
Operation* definingOp = value.getDefiningOp();
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
<< "values";
diagnostic.attachNote(op->getLoc())
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
});
}
});
}
} // namespace
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (Operation& op : funcOp.getOps()) {
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
continue;
if (isCompileTimeOp(&op))
continue;
diagnostics.report(&op, [](Operation* illegalOp) {
illegalOp->emitOpError(
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
});
}
checkWeightUseChains(funcOp, diagnostics);
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
return success(!diagnostics.hasFailure());
}
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
pim::CappedDiagnosticReporter diagnostics;
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
(void) verifyComputeLikeInputs(
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
}
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
computeBatchOp.getInputs(),
/*allowChannelReceiveInputs=*/false,
"spat.compute_batch",
diagnostics);
verifyNoExternalTensorCaptures(
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
}
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
return success(!diagnostics.hasFailure());
}
} // namespace onnx_mlir
@@ -0,0 +1,11 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"
namespace onnx_mlir {
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,19 +1,14 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGeneratedPrePatterns(patterns, ctx); }
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateGeneratedConversionPatterns(patterns, ctx);
populateElementwisePatterns(patterns, ctx);
populateGemmPatterns(patterns, ctx);
populateConvPatterns(patterns, ctx);
@@ -27,6 +22,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
populateResizePatterns(patterns, ctx);
populateReshapePatterns(patterns, ctx);
populateSplitPatterns(patterns, ctx);
populateTransposePatterns(patterns, ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
populateWeightPromotionPatterns(patterns, ctx);
}
} // namespace onnx_mlir
@@ -1,38 +1,39 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -0,0 +1,18 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
} // namespace
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<removeLRN>(ctx);
}
} // namespace onnx_mlir
@@ -11,7 +11,7 @@
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
ConversionPatternRewriter& rewriter) const override;
};
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
auto biasType = cast<RankedTensorType>(bias.getType());
if (biasType.getRank() != 1)
@@ -53,23 +51,100 @@ static Value createPaddedRows(Value tensorValue,
if (tensorType.getDimSize(0) == paddedRows)
return tensorValue;
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
auto paddedType = RankedTensorType::get(
{paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
rewriter.getIndexAttr(0)};
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
auto* padBlock = new Block();
for (int i = 0; i < 2; i++)
for (int i = 0; i < 2; ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
static Value packRowsForParallelGemm(
Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
if (packFactor == 1)
return rows;
const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
const int64_t paddedNumRows = packedNumRows * packFactor;
const int64_t rowWidth = rowsType.getDimSize(1);
auto groupedType =
RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
auto packedType =
RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc);
Value groupedRows = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedRows,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
return tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedRows,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
static Value unpackRowsFromParallelGemm(Value packedRows,
RankedTensorType packedRowsType,
int64_t unpackedRows,
int64_t rowWidth,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
if (packFactor == 1)
return packedRows;
const int64_t packedNumRows = packedRowsType.getDimSize(0);
const int64_t paddedNumRows = packedNumRows * packFactor;
auto expandedType = RankedTensorType::get(
{packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto paddedType =
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
auto unpackedType =
RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
Value expandedRows = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedRows,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedRows = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedRows,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
if (paddedNumRows == unpackedRows)
return paddedRows;
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides);
}
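Note on `packRowsForParallelGemm` / `unpackRowsFromParallelGemm`: the pad, expand, and collapse sequence can be illustrated on a plain row-major buffer. Since both reshapes keep the row-major element order, packing reduces to zero-padding the row count up to a multiple of `packFactor` and reinterpreting the shape. This is a sketch under that assumption; `Matrix`, `ceilDiv`, `packRows`, and `unpackRows` are hypothetical names, not project APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical dense row-major matrix standing in for a ranked tensor.
struct Matrix {
  int64_t rows;
  int64_t cols;
  std::vector<float> data; // rows * cols elements
};

int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// [N, W] -> [ceil(N / f), f * W]: pad with zero rows, then group f consecutive
// rows into one packed row. The flat buffer only grows by the padding.
Matrix packRows(const Matrix& m, int64_t packFactor) {
  if (packFactor == 1)
    return m;
  const int64_t packedRows = ceilDiv(m.rows, packFactor);
  Matrix packed{packedRows, packFactor * m.cols, m.data};
  packed.data.resize(static_cast<std::size_t>(packedRows * packFactor * m.cols), 0.0f);
  return packed;
}

// [P, f * W] -> [P * f, W] -> first unpackedRows rows: undo the grouping and
// drop the padding rows that packRows appended.
Matrix unpackRows(const Matrix& packed, int64_t unpackedRows, int64_t rowWidth, int64_t packFactor) {
  if (packFactor == 1)
    return packed;
  Matrix out{unpackedRows, rowWidth, packed.data};
  out.data.resize(static_cast<std::size_t>(unpackedRows * rowWidth));
  return out;
}
```

Packing a 3x2 matrix with a factor of 2 gives a 2x4 matrix whose last two elements are padding, and unpacking with `unpackedRows = 3` and `rowWidth = 2` restores the original buffer.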
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
Value wTrans,
RankedTensorType wType,
@@ -108,7 +183,30 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
}
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
}
static Value createConvWeightMatrix(
Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
auto buildWeightMatrix = [&](Value weight) -> Value {
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
weight,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult();
};
if (isCompileTimeComputable(w))
return buildWeightMatrix(w);
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) {
spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
});
return computeOp.getResult(0);
}
static Value buildPackedBias(bool hasBias,
@@ -134,7 +232,7 @@ static Value buildPackedBias(bool hasBias,
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType);
}
static Value createIm2colRowComputes(Value x,
@@ -165,7 +263,6 @@ static Value createIm2colRowComputes(Value x,
Location loc) {
auto elemType = xType.getElementType();
constexpr size_t numInputs = 1;
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
auto im2colComputeOp =
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
Value paddedInput = xArg;
@@ -190,8 +287,8 @@ static Value createIm2colRowComputes(Value x,
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
tensor::YieldOp::create(rewriter, loc, zero.getResult());
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType);
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
paddedInput = padOp.getResult();
}
@@ -199,13 +296,14 @@ static Value createIm2colRowComputes(Value x,
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
// until the late PIM unrolling step.
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches);
auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch);
auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth);
auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
rewriter.setInsertionPointToStart(im2colLoop.getBody());
@@ -252,28 +350,8 @@ static Value createIm2colRowComputes(Value x,
Value im2col = im2colLoop.getResult(0);
Value gemmInputRows = im2col;
if (packFactor != 1) {
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
loc,
groupedType,
paddedIm2col,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
loc,
packedType,
groupedIm2col,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
}
if (packFactor != 1)
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
});
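Reviewer note: the inline pad / expand / collapse sequence removed in this hunk (now behind packRowsForParallelGemm) turns [numPatches, patchSize] into [packedNumRows, packFactor * patchSize]. The sketch below reproduces that shape arithmetic on a row-major buffer. The names ceilDiv, packRowsHost and unpackRowsHost are hypothetical, and zero-filled padding rows are an assumption, since createPaddedRows is not shown in this diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring ceilIntegerDivide for non-negative a.
int64_t ceilDiv(int64_t a, int64_t b) {
  assert(a >= 0 && b > 0);
  return (a + b - 1) / b;
}

// rows is row-major [numRows, rowSize]. Returns row-major
// [ceilDiv(numRows, packFactor), packFactor * rowSize].
// Assumption: padding rows are zero-filled.
std::vector<float> packRowsHost(const std::vector<float>& rows,
                                int64_t numRows,
                                int64_t rowSize,
                                int64_t packFactor) {
  assert(static_cast<int64_t>(rows.size()) == numRows * rowSize);
  const int64_t packedNumRows = ceilDiv(numRows, packFactor);
  // Padding to packedNumRows * packFactor rows and regrouping keeps the
  // row-major data in place: the result is the input plus trailing zeros.
  std::vector<float> packed(packedNumRows * packFactor * rowSize, 0.0f);
  std::copy(rows.begin(), rows.end(), packed.begin());
  return packed;
}

// Inverse direction: drop the padding rows again.
std::vector<float> unpackRowsHost(const std::vector<float>& packed,
                                  int64_t numRows,
                                  int64_t rowSize) {
  assert(static_cast<int64_t>(packed.size()) >= numRows * rowSize);
  return std::vector<float>(packed.begin(), packed.begin() + numRows * rowSize);
}
```

With 3 rows of size 2 and packFactor 2, the packed buffer holds 2 rows of size 4, the last two elements being padding.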
@@ -291,41 +369,20 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
int64_t packFactor,
ConversionPatternRewriter& rewriter,
Location loc) {
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
const int64_t paddedNumPatches = packedNumRows * packFactor;
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
Value gemmOut;
if (packFactor == 1) {
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
}
else {
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
loc,
expandedType,
packedOutput,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
loc,
paddedType,
expandedOutput,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
gemmOut = paddedOutput;
if (paddedNumPatches != numPatches) {
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
}
gemmOut = unpackRowsFromParallelGemm(packedOutput,
cast<RankedTensorType>(packedOutput.getType()),
numPatches,
numChannelsOut,
packFactor,
rewriter,
loc);
}
// Restore to NCHW layout:
@@ -391,19 +448,11 @@ static Value lowerSingleConvGroup(Value x,
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
auto wDenseAttr = getHostConstDenseElementsAttr(w);
// Prepare weight matrix W for crossbar storage:
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
loc,
wFlatType,
w,
SmallVector<ReassociationIndices> {
{0},
{1, 2, 3}
});
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc);
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
@@ -412,7 +461,7 @@ static Value lowerSingleConvGroup(Value x,
DenseElementsAttr biasDenseAttr;
if (hasB) {
gemmBias = b;
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
biasDenseAttr = getHostConstDenseElementsAttr(b);
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
}
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
@@ -597,10 +646,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
return failure();
}
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
int64_t padHeightBegin = 0;
int64_t padHeightEnd = 0;
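Reviewer note: judging from the lines it replaces, getOptionalI64Attr(attr, index, default) appears to fold the "attr ? getI64FromArrayAttr(*attr, index) : default" ternary into one call. A minimal sketch of that behaviour with std::optional standing in for the optional ArrayAttr; getOptionalI64 is a hypothetical name, and the real helper's signature is not shown in this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in: returns values[index] when the attribute is present,
// otherwise the supplied default (e.g. stride or dilation 1).
int64_t getOptionalI64(const std::optional<std::vector<int64_t>>& values,
                       std::size_t index,
                       int64_t defaultValue) {
  if (!values)
    return defaultValue;
  assert(index < values->size());
  return (*values)[index];
}
```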
@@ -608,10 +657,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
int64_t padWidthEnd = 0;
if (padsAttr) {
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
padHeightBegin = getI64Attr(*padsAttr, 0);
padWidthBegin = getI64Attr(*padsAttr, 1);
padHeightEnd = getI64Attr(*padsAttr, 2);
padWidthEnd = getI64Attr(*padsAttr, 3);
}
else {
// Compute padding from auto_pad attribute
@@ -717,7 +766,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
}
Value result;
if (llvm::all_of(groupResults, isHostFoldableValue)) {
if (llvm::all_of(groupResults, isCompileTimeComputable)) {
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
}
else {
@@ -7,7 +7,7 @@
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -83,7 +83,7 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
}
auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), broadcastedAttr, resultType);
}
static FailureOr<Value>
@@ -121,7 +121,7 @@ static FailureOr<Value> materializeReciprocalTensor(Value value,
}
auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), reciprocalAttr, resultType);
}
template <typename OnnxOp, typename SpatialOp>
File diff suppressed because it is too large
@@ -1,3 +1,5 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -5,12 +7,10 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <functional>
#include <numeric>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -19,14 +19,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
}
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
}
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
ArrayRef<int64_t> rhsBatchShape) {
if (lhsBatchShape.empty())
@@ -38,62 +30,60 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
}
static Value collapseBatchDims(Value value,
int64_t batchSize,
int64_t rows,
int64_t cols,
PatternRewriter& rewriter,
Location loc) {
static Value
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 2 || type.getRank() == 3)
return value;
auto collapsedType =
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
};
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
reassociation.front().push_back(dim);
auto buildCollapsed = [&](Value input) -> Value {
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
};
if (isHostFoldableValue(value))
return buildCollapsed(value);
auto collapseCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
});
return collapseCompute.getResult(0);
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
}
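Reviewer note: collapseBatchDims groups all leading batch dimensions into one and keeps the last two dimensions as separate groups. The sketch below builds the same reassociation with standard containers. Reassociation and buildBatchCollapseReassociation are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for SmallVector<ReassociationIndices>.
using Reassociation = std::vector<std::vector<int64_t>>;

// Groups dims [0, rank - 2) into one batch dim and keeps the row and
// column dims as their own groups.
Reassociation buildBatchCollapseReassociation(int64_t rank) {
  assert(rank >= 3);
  Reassociation groups(3);
  for (int64_t dim = 0; dim < rank - 2; ++dim)
    groups[0].push_back(dim);
  groups[1].push_back(rank - 2);
  groups[2].push_back(rank - 1);
  return groups;
}
```

For a rank-5 tensor this yields {{0, 1, 2}, {3}, {4}}.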
static Value expandBatchDims(Value value,
RankedTensorType outputType,
size_t batchRank,
PatternRewriter& rewriter,
Location loc) {
static Value
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
if (cast<RankedTensorType>(value.getType()) == outputType)
return value;
SmallVector<ReassociationIndices> reassociation = {
ReassociationIndices {},
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
ReassociationIndices {static_cast<int64_t>(batchRank)},
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
};
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
for (size_t dim = 0; dim < batchRank; ++dim)
reassociation.front().push_back(static_cast<int64_t>(dim));
auto expandCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
spatial::SpatYieldOp::create(rewriter, loc, expanded);
auto buildExpanded = [&](Value input) -> Value {
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
};
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
}
static Value ensureBatchedTensor(
Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
auto type = cast<RankedTensorType>(value.getType());
if (type.getRank() == 3)
return value;
auto batchedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
auto buildExpanded = [&](Value input) -> Value {
return tensor::ExpandShapeOp::create(rewriter,
loc,
batchedType,
input,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
return expandCompute.getResult(0);
};
return materializeOrComputeUnary(value, batchedType, rewriter, loc, buildExpanded);
}
static Value extractBatchMatrix(Value value,
@@ -112,7 +102,7 @@ static Value extractBatchMatrix(Value value,
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes = {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
auto buildMatrix = [&](Value input) -> Value {
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
@@ -126,14 +116,7 @@ static Value extractBatchMatrix(Value value,
});
};
if (isHostFoldableValue(value))
return buildMatrix(value);
auto batchMatrixCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
});
return batchMatrixCompute.getResult(0);
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
}
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
@@ -150,18 +133,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
perm = {0, 2, 1};
}
auto buildTranspose = [&](Value input) -> Value {
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
};
if (isHostFoldableValue(value))
return buildTranspose(value);
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
});
return transposeCompute.getResult(0);
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
}
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
@@ -185,28 +157,527 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
return transposeCompute.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
auto sourceType = cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> highPads;
highPads.reserve(sourceType.getRank());
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
auto* padBlock = new Block();
for (int64_t i = 0; i < sourceType.getRank(); ++i)
padBlock->addArgument(rewriter.getIndexType(), loc);
padOp.getRegion().push_back(padBlock);
rewriter.setInsertionPointToStart(padBlock);
auto zero = getOrCreateConstant(
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
tensor::YieldOp::create(rewriter, loc, zero);
rewriter.setInsertionPointAfter(padOp);
return padOp.getResult();
}
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
static Value createPaddedBatchedInputCompute(Value input,
RankedTensorType paddedInputType,
PatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
if (inputType == paddedInputType)
return input;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
});
return computeOp.getResult(0);
}
static FailureOr<Value> materializePaddedBatchedWeight(
Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) {
auto sourceType = cast<RankedTensorType>(value.getType());
if (sourceType == resultType)
return value;
auto denseAttr = getHostConstDenseElementsAttr(value);
if (!denseAttr)
return failure();
const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
const int64_t targetRows = resultType.getDimSize(1);
const int64_t targetCols = resultType.getDimSize(2);
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx);
const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
for (int64_t row = 0; row < sourceRows; ++row)
for (int64_t col = 0; col < sourceCols; ++col)
resultValues[targetBatchBase + row * targetCols + col] = sourceValues[sourceBatchBase + row * sourceCols + col];
}
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
}
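Reviewer note: materializePaddedBatchedWeight copies each source matrix into the top-left corner of a zero-filled target matrix and reuses batch 0 when the source has a single batch. The sketch below shows the same indexing on row-major float buffers. padBatchedWeightHost is a hypothetical name, and a rank-2 source is modelled here as sourceBatch == 1.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side reference, not part of the patch.
// source: row-major [sourceBatch, sourceRows, sourceCols].
// result: row-major [targetBatch, targetRows, targetCols], zero-padded.
std::vector<float> padBatchedWeightHost(const std::vector<float>& source,
                                        int64_t sourceBatch,
                                        int64_t sourceRows,
                                        int64_t sourceCols,
                                        int64_t targetBatch,
                                        int64_t targetRows,
                                        int64_t targetCols) {
  assert(sourceBatch == 1 || sourceBatch == targetBatch);
  assert(sourceRows <= targetRows && sourceCols <= targetCols);
  assert(static_cast<int64_t>(source.size()) == sourceBatch * sourceRows * sourceCols);
  std::vector<float> result(targetBatch * targetRows * targetCols, 0.0f);
  for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
    // A single source batch is broadcast to every target batch.
    const int64_t sourceBatchBase = (sourceBatch == 1 ? 0 : batchIdx) * sourceRows * sourceCols;
    const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
    for (int64_t row = 0; row < sourceRows; ++row)
      for (int64_t col = 0; col < sourceCols; ++col)
        result[targetBatchBase + row * targetCols + col] = source[sourceBatchBase + row * sourceCols + col];
  }
  return result;
}
```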
static Value extractBatchedATile(Value a,
int64_t sourceBatchCount,
Value batch,
Value row,
Value kOffset,
RankedTensorType aTileType,
PatternRewriter& rewriter,
Location loc) {
auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
SmallVector<OpFoldResult> offsets {
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
auto slice =
tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, getUnitStrides(rewriter, 3));
return tensor::CollapseShapeOp::create(rewriter,
loc,
aTileType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static Value extractBatchedBTile(Value b,
int64_t sourceBatchCount,
Value batch,
Value kOffset,
Value hOffset,
RankedTensorType bTileType,
PatternRewriter& rewriter,
Location loc) {
auto bSliceType =
RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
SmallVector<OpFoldResult> offsets {
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
rewriter.getIndexAttr(bTileType.getDimSize(0)),
rewriter.getIndexAttr(bTileType.getDimSize(1))};
auto slice =
tensor::ExtractSliceOp::create(rewriter, loc, bSliceType, b, offsets, sizes, getUnitStrides(rewriter, 3));
return tensor::CollapseShapeOp::create(rewriter,
loc,
bTileType,
slice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static Value getBatchLaneIndex(
Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
return affineFloorDivConst(
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
}
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
Value b,
RankedTensorType aType,
int64_t aBatchCount,
RankedTensorType bType,
int64_t bBatchCount,
RankedTensorType partialPiecesType,
int64_t numOutRows,
int64_t numKSlices,
int64_t numOutHSlices,
PatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = partialPiecesType.getDimSize(0);
auto batchOp = createSpatComputeBatch(
rewriter,
loc,
TypeRange {partialPiecesType},
laneCount,
ValueRange {b},
ValueRange {a},
[&](detail::SpatComputeBatchBodyArgs args) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp);
Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp);
Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp);
auto aTileType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
auto bTileType = RankedTensorType::get(
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
bType.getElementType());
auto pieceType =
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
Value aTile =
extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc);
Value bTile =
extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc);
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
});
assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed");
return *batchOp;
}
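Reviewer note: the lambda in createBatchedVmmBatch derives (batch, hSlice, kSlice, row) from the flat lane index with floor-div and mod by constants. The sketch below performs the same decomposition with plain integers, which can help when checking the lane layout by hand. VmmLane and decodeVmmLane are hypothetical names; lanes are assumed non-negative, so C++ integer division matches floor division.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical plain-integer mirror of the affine lane arithmetic.
struct VmmLane {
  int64_t batch;
  int64_t hSlice;
  int64_t kSlice;
  int64_t row;
};

VmmLane decodeVmmLane(int64_t lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices) {
  assert(lane >= 0 && numOutRows > 0 && numKSlices > 0 && numOutHSlices > 0);
  const int64_t outerLane = lane / numOutRows;
  const int64_t sliceLane = outerLane % (numKSlices * numOutHSlices);
  VmmLane decoded;
  decoded.row = lane % numOutRows;
  decoded.batch = lane / (numOutRows * numKSlices * numOutHSlices);
  decoded.kSlice = sliceLane % numKSlices;
  decoded.hSlice = sliceLane / numKSlices;
  return decoded;
}
```

Equivalently, lane = ((batch * numOutHSlices + hSlice) * numKSlices + kSlice) * numOutRows + row, so row varies fastest and batch slowest.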
static Value extractDynamicBatchedBColumn(Value matrix,
int64_t sourceBatchCount,
Value batch,
Value column,
RankedTensorType vectorType,
PatternRewriter& rewriter,
Location loc) {
auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(batch),
rewriter.getIndexAttr(0),
column};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value columnSlice = tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides);
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
Value collapsed = tensor::CollapseShapeOp::create(rewriter,
loc,
collapsedType,
columnSlice,
SmallVector<ReassociationIndices> {
{0, 1, 2}
})
.getResult();
return tensor::ExpandShapeOp::create(rewriter,
loc,
vectorType,
collapsed,
SmallVector<ReassociationIndices> {
{0, 1}
})
.getResult();
}
static Value extractDynamicBatchedBRow(Value matrix,
int64_t sourceBatchCount,
Value batch,
Value row,
RankedTensorType vectorType,
PatternRewriter& rewriter,
Location loc) {
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(batch),
row,
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
auto rowSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
return tensor::CollapseShapeOp::create(rewriter,
loc,
vectorType,
rowSlice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static Value extractDynamicBatchedRowVector(Value matrix,
int64_t sourceBatchCount,
Value batch,
Value row,
RankedTensorType vectorType,
PatternRewriter& rewriter,
Location loc) {
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
: OpFoldResult(batch),
row,
rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
auto rowSlice =
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
return tensor::CollapseShapeOp::create(rewriter,
loc,
vectorType,
rowSlice,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
}
static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
int64_t aBatchCount,
Value b,
int64_t bBatchCount,
RankedTensorType aType,
RankedTensorType bType,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
bool bAlreadyTransposed,
PatternRewriter& rewriter,
Location loc) {
const int64_t numBatches = outType.getDimSize(0);
const int64_t numOutRows = outType.getDimSize(1);
const int64_t numOutCols = outType.getDimSize(2);
const int64_t reductionSize = aType.getDimSize(2);
const int64_t laneCount = numBatches * numOutRows * numOutCols;
auto batchOp = createSpatComputeBatch(
rewriter,
loc,
TypeRange {scalarPiecesType},
laneCount,
ValueRange {},
ValueRange {a, b},
[&](detail::SpatComputeBatchBodyArgs args) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
Value aVector =
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
Value bVector =
bAlreadyTransposed
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2));
});
assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed");
return *batchOp;
}
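Reviewer note: createBatchedVvdmulBatch assigns one lane per output scalar and computes it as a dot product of a row of A with a column of B. The sketch below is a host-side reference of that scheme, useful for comparing against validation outputs. batchedMatMulByLanes is a hypothetical name, and the sketch assumes no batch broadcasting and a B that is not pre-transposed, which is narrower than what the patch handles.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference: a is row-major [B, M, K], b is row-major [B, K, N].
// Returns one scalar per lane; lane order equals row-major [B, M, N].
std::vector<float> batchedMatMulByLanes(const std::vector<float>& a,
                                        const std::vector<float>& b,
                                        int64_t numBatches,
                                        int64_t numOutRows,
                                        int64_t reductionSize,
                                        int64_t numOutCols) {
  assert(static_cast<int64_t>(a.size()) == numBatches * numOutRows * reductionSize);
  assert(static_cast<int64_t>(b.size()) == numBatches * reductionSize * numOutCols);
  const int64_t laneCount = numBatches * numOutRows * numOutCols;
  std::vector<float> pieces(laneCount, 0.0f);
  for (int64_t lane = 0; lane < laneCount; ++lane) {
    // Same decomposition as the batch body: batch, then row, then column.
    const int64_t batch = lane / (numOutRows * numOutCols);
    const int64_t batchLane = lane % (numOutRows * numOutCols);
    const int64_t row = batchLane / numOutCols;
    const int64_t column = batchLane % numOutCols;
    float acc = 0.0f;
    for (int64_t k = 0; k < reductionSize; ++k)
      acc += a[(batch * numOutRows + row) * reductionSize + k] * b[(batch * reductionSize + k) * numOutCols + column];
    pieces[lane] = acc;
  }
  return pieces;
}
```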
static Value createBatchedDynamicOutputCompute(Value scalarPieces,
RankedTensorType scalarPiecesType,
RankedTensorType outType,
PatternRewriter& rewriter,
Location loc) {
const int64_t laneCount = scalarPiecesType.getDimSize(0);
const int64_t numOutRows = outType.getDimSize(1);
const int64_t numOutCols = outType.getDimSize(2);
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());
auto computeOp =
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) {
Value outputInit =
tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(loop.getBody());
Value lane = loop.getInductionVar();
Value outputAcc = loop.getRegionIterArgs().front();
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scalar = tensor::ExtractSliceOp::create(
rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
outputScalarType,
scalar,
SmallVector<ReassociationIndices> {
{0},
{1, 2}
});
SmallVector<OpFoldResult> outputOffsets {batch, row, column};
SmallVector<OpFoldResult> outputSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
scf::YieldOp::create(
rewriter,
loc,
tensor::InsertSliceOp::create(
rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
.getResult());
rewriter.setInsertionPointAfter(loop);
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
});
return computeOp.getResult(0);
}
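// Swaps the last two dimensions of a rank-3 batched result (permutation {0, 2, 1}) inside a spatial compute region.
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {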
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
return transposeCompute.getResult(0);
}
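// Extracts the numOutRows x crossbarSize partial-sum piece for (batch, hSlice, kSlice).
// Rows of partialPiecesArg are ordered as [batch][hSlice][kSlice][row].
static Value extractBatchedReductionPiece(Value partialPiecesArg,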
Value batch,
Value hSlice,
int64_t kSlice,
RankedTensorType pieceType,
int64_t numKSlices,
int64_t numOutHSlices,
int64_t numOutRows,
PatternRewriter& rewriter,
Location loc) {
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp);
Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp);
Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
SmallVector<OpFoldResult> offsets {pieceOffset, rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
return tensor::ExtractSliceOp::create(
rewriter, loc, pieceType, partialPiecesArg, offsets, sizes, getUnitStrides(rewriter, 2));
}
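// Sums the numKSlices partial pieces of one (batch, hSlice) tile with a pairwise (tree) reduction of SpatVAddOp.
// Requires numKSlices >= 1.
static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,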
Value batch,
Value hSlice,
RankedTensorType pieceType,
int64_t numKSlices,
int64_t numOutHSlices,
int64_t numOutRows,
PatternRewriter& rewriter,
Location loc) {
SmallVector<Value> activePieces;
activePieces.reserve(numKSlices);
for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
activePieces.push_back(extractBatchedReductionPiece(
partialPiecesArg, batch, hSlice, kSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc));
while (activePieces.size() > 1) {
SmallVector<Value> nextPieces;
nextPieces.reserve((activePieces.size() + 1) / 2);
for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2)
nextPieces.push_back(
spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1])
.getResult());
if (activePieces.size() % 2 != 0)
nextPieces.push_back(activePieces.back());
activePieces = std::move(nextPieces);
}
return activePieces.front();
}
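The pairwise reduction in `reduceBatchedPartialPiecesForHSlice` can be illustrated outside MLIR. What follows is a minimal standalone sketch, assuming plain `int64_t` values in place of tensor pieces and `+` in place of `spatial::SpatVAddOp`; `treeReduce` is a hypothetical name and is not part of this change.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: sums `pieces` level by level, pairing neighbours the
// same way the loop above pairs partial pieces.
int64_t treeReduce(std::vector<int64_t> pieces) {
  assert(!pieces.empty() && "at least one piece is required");
  while (pieces.size() > 1) {
    std::vector<int64_t> next;
    next.reserve((pieces.size() + 1) / 2);
    for (std::size_t i = 0; i + 1 < pieces.size(); i += 2)
      next.push_back(pieces[i] + pieces[i + 1]);
    // An odd trailing element has no partner and is carried to the next level.
    if (pieces.size() % 2 != 0)
      next.push_back(pieces.back());
    pieces = std::move(next);
  }
  return pieces.front();
}
```

The number of additions is still one fewer than the number of pieces, but the longest chain of dependent additions shrinks from `numKSlices - 1` to roughly `log2(numKSlices)`.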
static Value createBatchedReductionCompute(Value partialPieces,
RankedTensorType partialPiecesType,
RankedTensorType outType,
RankedTensorType paddedOutType,
int64_t numBatches,
int64_t numKSlices,
PatternRewriter& rewriter,
Location loc) {
auto computeOp = createSpatCompute<1>(
rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) {
const int64_t numOutRows = outType.getDimSize(1);
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
partialPiecesType.getElementType());
auto outputSliceType = RankedTensorType::get({1, numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
partialPiecesType.getElementType());
Value outputInit =
tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
Value cNumBatches = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numBatches);
Value cNumOutHSlices =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit});
rewriter.setInsertionPointToStart(batchLoop.getBody());
Value batch = batchLoop.getInductionVar();
Value batchAcc = batchLoop.getRegionIterArgs().front();
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc});
rewriter.setInsertionPointToStart(hLoop.getBody());
Value hSlice = hLoop.getInductionVar();
Value outputAcc = hLoop.getRegionIterArgs().front();
Value reduced = reduceBatchedPartialPiecesForHSlice(
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
loc,
outputSliceType,
reduced,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
Value hOffset =
affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
SmallVector<OpFoldResult> outputSizes {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
scf::YieldOp::create(
rewriter,
loc,
tensor::InsertSliceOp::create(
rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
.getResult());
rewriter.setInsertionPointAfter(hLoop);
scf::YieldOp::create(rewriter, loc, hLoop.getResult(0));
rewriter.setInsertionPointAfter(batchLoop);
Value paddedOutput = batchLoop.getResult(0);
Value result = paddedOutput;
if (paddedOutType != outType) {
SmallVector<OpFoldResult> outputOffsets {
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(numBatches),
rewriter.getIndexAttr(outType.getDimSize(1)),
rewriter.getIndexAttr(outType.getDimSize(2))};
result = tensor::ExtractSliceOp::create(
rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3));
}
spatial::SpatYieldOp::create(rewriter, loc, result);
});
return computeOp.getResult(0);
}
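// Static shape facts shared by the MatMul rewrite patterns: operand and result types,
// flattened batch counts, and the m/k/n matrix dimensions.
struct MatMulShapeInfo {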
RankedTensorType lhsType;
RankedTensorType rhsType;
RankedTensorType outType;
SmallVector<int64_t> batchShape;
int64_t lhsBatch;
int64_t rhsBatch;
int64_t batch;
int64_t m;
int64_t k;
int64_t n;
};
static FailureOr<MatMulShapeInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
@@ -215,8 +686,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure();
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
return failure();
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|| !haveStaticPositiveShape(outType.getShape()))
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
return failure();
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
@@ -224,10 +694,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
if (failed(batchShape))
return failure();
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
@@ -246,30 +716,38 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
return failure();
}
Location loc = matmulOp.getLoc();
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
}
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
int64_t lhsBatchForGemm = lhsBatch;
int64_t rhsBatchForGemm = rhsBatch;
int64_t gemmM = m;
int64_t gemmK = k;
int64_t gemmN = n;
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto shapeInfo = analyzeMatMulShape(matmulOp);
if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
return failure();
Location loc = matmulOp.getLoc();
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
int64_t gemmM = shapeInfo->m;
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = rhsBatch;
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = lhsBatch;
gemmM = n;
gemmN = m;
rhsBatchForGemm = shapeInfo->lhsBatch;
gemmM = shapeInfo->n;
gemmN = shapeInfo->m;
}
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
auto gemmType = RankedTensorType::get({gemmM, gemmN}, shapeInfo->outType.getElementType());
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
if (outType.getRank() == 2) {
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
@@ -285,8 +763,9 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
.getY();
if (useTransposedForm) {
auto transposeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
Value transposed =
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
spatial::SpatYieldOp::create(rewriter, loc, transposed);
});
gemmResult = transposeCompute.getResult(0);
@@ -294,48 +773,115 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
rewriter.replaceOp(matmulOp, gemmResult);
return success();
}
};
SmallVector<Value> batchResults;
batchResults.reserve(batch);
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
Value gemmResult = ONNXGemmOp::create(rewriter,
loc,
gemmType,
lhsMatrix,
rhsMatrix,
none,
rewriter.getF32FloatAttr(1.0f),
rewriter.getF32FloatAttr(1.0f),
rewriter.getBoolAttr(false),
rewriter.getBoolAttr(false))
.getY();
auto batchResultCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
Value resultMatrix = input;
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
auto shapeInfo = analyzeMatMulShape(matmulOp);
if (failed(shapeInfo))
return failure();
if (shapeInfo->outType.getRank() == 2)
return failure();
Location loc = matmulOp.getLoc();
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
int64_t gemmM = shapeInfo->m;
int64_t gemmK = shapeInfo->k;
int64_t gemmN = shapeInfo->n;
if (useTransposedForm) {
resultMatrix = ONNXTransposeOp::create(rewriter,
loc,
RankedTensorType::get({m, n}, outType.getElementType()),
input,
rewriter.getI64ArrayAttr({1, 0}));
}
Value expanded = tensor::ExpandShapeOp::create(rewriter,
loc,
batchedOutType,
resultMatrix,
SmallVector<ReassociationIndices> {
{0, 1},
{2}
});
spatial::SpatYieldOp::create(rewriter, loc, expanded);
});
batchResults.push_back(batchResultCompute.getResult(0));
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
lhsBatchForGemm = shapeInfo->rhsBatch;
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
rhsBatchForGemm = shapeInfo->lhsBatch;
gemmM = shapeInfo->n;
gemmN = shapeInfo->m;
}
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
auto lhsBatchedType = cast<RankedTensorType>(lhs.getType());
auto rhsBatchedType = cast<RankedTensorType>(rhs.getType());
auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());
if (isCompileTimeComputable(rhs)) {
const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
auto paddedLhsType = RankedTensorType::get(
{lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
rhsBatchedType.getElementType(),
rhsBatchedType.getEncoding());
auto paddedOutType =
RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());
auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
if (succeeded(paddedRhs)) {
Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
shapeInfo->outType.getElementType());
auto batchOp = createBatchedVmmBatch(paddedLhs,
*paddedRhs,
paddedLhsType,
lhsBatchForGemm,
paddedRhsType,
rhsBatchForGemm,
partialPiecesType,
gemmM,
numKSlices,
numOutHSlices,
rewriter,
loc);
Value result = createBatchedReductionCompute(batchOp.getResult(0),
partialPiecesType,
directOutType,
paddedOutType,
shapeInfo->batch,
numKSlices,
rewriter,
loc);
if (useTransposedForm)
result = transposeBatchedOutput(
result,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter,
loc);
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result);
return success();
}
}
const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
auto batchOp = createBatchedVvdmulBatch(lhs,
lhsBatchForGemm,
rhs,
rhsBatchForGemm,
lhsBatchedType,
rhsBatchedType,
scalarPiecesType,
directOutType,
false,
rewriter,
loc);
Value result =
createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc);
if (useTransposedForm)
result = transposeBatchedOutput(
result,
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
rewriter,
loc);
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
rewriter.replaceOp(matmulOp, result);
return success();
}
@@ -344,7 +890,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
} // namespace
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.insert<MatMulToGemm>(ctx);
patterns.insert<MatMulToGemm, MatMulBatchedToSpatialComputes>(ctx);
}
} // namespace onnx_mlir
@@ -1,13 +1,16 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <limits>
#include <numeric>
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -16,26 +19,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
SmallVector<int64_t> normalizedAxes;
if (!axesAttr) {
normalizedAxes.reserve(rank);
for (int64_t axis = 0; axis < rank; axis++)
normalizedAxes.push_back(axis);
return normalizedAxes;
}
normalizedAxes.reserve(axesAttr.size());
for (Attribute attr : axesAttr) {
int64_t axis = cast<IntegerAttr>(attr).getInt();
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
}
llvm::sort(normalizedAxes);
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
return normalizedAxes;
}
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
SmallVector<bool> reducedAxes(rank, false);
for (int64_t axis : axes) {
@@ -50,6 +33,184 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
}
// Result type with every reduced axis kept as size 1 (the keepdims layout).
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? 1 : dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
// Result type containing only the non-reduced axes, in their original order.
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
if (!isReduced)
shape.push_back(dim);
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
}
// Type of the slice averaged by a single lane: full extent on reduced axes, size 1 on kept axes.
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> shape;
shape.reserve(inputType.getRank());
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
shape.push_back(isReduced ? dim : 1);
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
}
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
shape.front() = laneCount;
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
}
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
SmallVector<int64_t> keptAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
if (!isReduced)
keptAxes.push_back(static_cast<int64_t>(axis));
return keptAxes;
}
// Recovers the coordinate of one kept axis from a flattened lane id:
// (lane floordiv stride) mod dimSize. Size-1 axes always map to index 0.
static Value
computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternRewriter& rewriter, Location loc) {
if (dimSize == 1)
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
MLIRContext* context = rewriter.getContext();
AffineExpr d0 = getAffineDimExpr(0, context);
AffineExpr expr = d0;
if (stride != 1)
expr = expr.floorDiv(stride);
// dimSize != 1 is guaranteed by the early return above.
expr = expr % dimSize;
return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
}
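As a plain-integer illustration of the index arithmetic in `computeLaneIndex`: a flattened lane id is split into one coordinate per kept axis using row-major strides, computed the same way as `keptAxisStrides` in `buildReduceMeanKeepdimsBatch`. This is a minimal sketch; `decomposeLane` is a hypothetical name and the shape used in the note below is example data.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: splits a flattened lane id into one coordinate per
// kept axis, assuming row-major order (the last kept axis varies fastest).
std::vector<int64_t> decomposeLane(int64_t lane, const std::vector<int64_t>& keptDims) {
  std::vector<int64_t> strides(keptDims.size(), 1);
  int64_t laneCount = 1;
  for (int64_t i = static_cast<int64_t>(keptDims.size()) - 1; i >= 0; --i) {
    strides[i] = laneCount;
    laneCount *= keptDims[i];
  }
  assert(lane >= 0 && lane < laneCount);
  std::vector<int64_t> coords(keptDims.size());
  for (std::size_t i = 0; i < keptDims.size(); ++i)
    coords[i] = (lane / strides[i]) % keptDims[i];
  return coords;
}
```

For a kept shape of {2, 3, 4} the strides are {12, 4, 1}, so lane 7 maps to coordinates (0, 1, 3), since 0 * 12 + 1 * 4 + 3 = 7.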
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
ArrayRef<bool> reducedAxes,
RankedTensorType batchType,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
auto sliceType = getReducedSliceType(inputType, reducedAxes);
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
int64_t laneCount = 1;
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
keptAxisStrides[index] = laneCount;
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
if (dimSize <= 0)
return failure();
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
return failure();
laneCount *= dimSize;
}
SmallVector<OpFoldResult> sliceOffsets;
SmallVector<OpFoldResult> sliceSizes;
SmallVector<OpFoldResult> insertOffsets;
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
sliceOffsets.reserve(inputType.getRank());
sliceSizes.reserve(inputType.getRank());
insertOffsets.reserve(inputType.getRank());
auto batchOp =
createSpatComputeBatch(rewriter,
loc,
TypeRange {batchType},
laneCount,
{},
ValueRange {input},
[&](detail::SpatComputeBatchBodyArgs args) {
size_t keptAxisIndex = 0;
sliceOffsets.clear();
sliceSizes.clear();
insertOffsets.clear();
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
sliceOffsets.push_back(rewriter.getIndexAttr(0));
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
continue;
}
Value axisIndex = computeLaneIndex(
args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
++keptAxisIndex;
sliceOffsets.push_back(axisIndex);
sliceSizes.push_back(rewriter.getIndexAttr(1));
}
insertOffsets.push_back(args.lane);
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
Value slice = tensor::ExtractSliceOp::create(
rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
createParallelInsertSliceIntoBatchOutput(
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
});
if (failed(batchOp))
return failure();
return (*batchOp).getResult(0);
}
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
RankedTensorType keepdimsType,
RankedTensorType compactKeptType,
ArrayRef<bool> reducedAxes,
ConversionPatternRewriter& rewriter,
Location loc) {
auto batchType = cast<RankedTensorType>(batchValue.getType());
if (batchType == keepdimsType)
return batchValue;
SmallVector<ReassociationIndices> collapseToFlat {{}};
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
collapseToFlat.front().push_back(axis);
SmallVector<ReassociationIndices> expandFlatToCompact(1);
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
expandFlatToCompact.front().push_back(axis);
SmallVector<ReassociationIndices> expandCompactToKeepdims;
ReassociationIndices pendingLeadingReducedAxes;
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
if (isReduced) {
if (expandCompactToKeepdims.empty())
pendingLeadingReducedAxes.push_back(axis);
else
expandCompactToKeepdims.back().push_back(axis);
continue;
}
expandCompactToKeepdims.emplace_back();
auto& group = expandCompactToKeepdims.back();
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
pendingLeadingReducedAxes.clear();
group.push_back(axis);
}
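// Leading reduced axes remain pending only when every axis is reduced; no group exists then,
// so guard the back() access instead of reading from an empty list.
if (!pendingLeadingReducedAxes.empty() && !expandCompactToKeepdims.empty())
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());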
auto reshapeCompute =
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
auto flatType =
RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
Value compact = flat;
if (compactKeptType != flatType)
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
Value keepdims = compact;
if (keepdimsType != compactKeptType)
keepdims = tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
});
return reshapeCompute.getResult(0);
}
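The grouping loop that builds `expandCompactToKeepdims` can be traced without MLIR types. The sketch below assumes `std::vector<int64_t>` in place of `ReassociationIndices`; `groupAxes` is a hypothetical name, and it simply returns no groups when every axis is reduced.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: one group per kept axis. Reduced axes before the first
// kept axis join the first group; later reduced axes join the group of the
// kept axis that precedes them.
std::vector<std::vector<int64_t>> groupAxes(const std::vector<bool>& reducedAxes) {
  std::vector<std::vector<int64_t>> groups;
  std::vector<int64_t> pendingLeading;
  for (std::size_t i = 0; i < reducedAxes.size(); ++i) {
    const auto axis = static_cast<int64_t>(i);
    if (reducedAxes[i]) {
      if (groups.empty())
        pendingLeading.push_back(axis);
      else
        groups.back().push_back(axis);
      continue;
    }
    // A kept axis opens a new group, absorbing any pending leading axes.
    groups.emplace_back(pendingLeading);
    pendingLeading.clear();
    groups.back().push_back(axis);
  }
  return groups;
}
```

For a mask of {reduced, kept, reduced, kept} this yields the groups {0, 1, 2} and {3}: each compact dimension expands into its kept axis plus the neighbouring size-1 reduced axes.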
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
SmallVector<ReassociationIndices> reassociation;
ReassociationIndices currentGroup;
@@ -72,56 +233,6 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
return reassociation;
}
static Value
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
constexpr size_t numInputs = 1;
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
});
return computeOp.getResult(0);
}
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
auto firstType = cast<RankedTensorType>(inputs.front().getType());
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
int64_t concatDimSize = 0;
for (Value input : inputs)
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
outputShape[axis] = concatDimSize;
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
if (llvm::all_of(inputs, isHostFoldableValue))
return createSpatConcat(rewriter, loc, axis, inputs);
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
});
return concatCompute.getResult(0);
}
static Value buildReduceMeanKeepdims(Value input,
ArrayRef<bool> reducedAxes,
int64_t axis,
RankedTensorType leafType,
ConversionPatternRewriter& rewriter,
Location loc) {
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
if (axis == rank)
return createAverageCompute(input, leafType, rewriter, loc);
if (reducedAxes[axis])
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
SmallVector<Value> reducedSlices;
reducedSlices.reserve(slices.size());
for (Value slice : slices)
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
return concatValues(reducedSlices, axis, rewriter, loc);
}
static Value squeezeReducedAxes(Value keepdimsValue,
RankedTensorType resultType,
ArrayRef<bool> reducedAxes,
@@ -129,13 +240,13 @@ static Value squeezeReducedAxes(Value keepdimsValue,
Location loc) {
if (resultType.getRank() == 0) {
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
arith::ConstantIndexOp::create(rewriter, loc, 0));
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
}
auto reassociation = buildCollapseReassociation(reducedAxes);
if (isHostFoldableValue(keepdimsValue))
if (isCompileTimeComputable(keepdimsValue))
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
auto squeezeCompute =
@@ -156,16 +267,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
if (inputType.getRank() == 0) {
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
return success();
}
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank());
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank());
auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
if (failed(axes))
return failure();
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
if (reducedAxes.empty() && inputType.getRank() != 0)
return failure();
Location loc = reduceMeanOp.getLoc();
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
int64_t laneCount = 1;
for (int64_t dim : compactKeptType.getShape())
laneCount *= dim;
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
auto lanePackedKeepdims =
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
if (failed(lanePackedKeepdims))
return failure();
Value reducedKeepdims =
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
if (reduceMeanOp.getKeepdims() != 0) {
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
@@ -23,43 +23,26 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
-template <typename ArrayAttrT>
-static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
-return cast<IntegerAttr>(arrayAttr[index]).getInt();
-}
-template <typename ArrayAttrT>
-static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
-return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
-}
-static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
+static Value materializeTileTensor(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
auto tileType = cast<RankedTensorType>(tile.getType());
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
-SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
-SmallVector<OpFoldResult> sizes;
-sizes.reserve(tileType.getRank());
-for (int64_t dimSize : tileType.getShape())
-sizes.push_back(rewriter.getIndexAttr(dimSize));
-SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
-return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
+return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
}
static Value
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
+Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
if (!useMinimumValue)
-return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
+return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType);
if (auto floatType = dyn_cast<FloatType>(elementType)) {
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
-return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
+return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType);
}
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
-return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
+return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType);
}
llvm_unreachable("unsupported pool element type");
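Reviewer note: the fill-value choice in `createPoolFillElement` can be illustrated outside MLIR. The sketch below is a plain C++ analogue, not the repository's code; `poolFillValueSketch` is a hypothetical name. It assumes a signed element type, mirroring the zero / negative-infinity / signed-minimum cases above.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <type_traits>

// Hypothetical analogue of createPoolFillElement for plain C++ types.
// useMinimumValue == false -> 0 (identity for a sum, as used by average pooling)
// useMinimumValue == true  -> -inf for floats, signed minimum for integers
//                             (identity for a max reduction)
template <typename T>
constexpr T poolFillValueSketch(bool useMinimumValue) {
  static_assert(std::is_arithmetic_v<T> && std::is_signed_v<T>,
                "sketch only covers signed integer and floating-point types");
  if (!useMinimumValue)
    return T(0);
  if constexpr (std::is_floating_point_v<T>)
    return -std::numeric_limits<T>::infinity();
  else
    return std::numeric_limits<T>::min();
}
```

With this fill value, padded positions never win a max reduction and never change a sum.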
@@ -166,7 +149,7 @@ static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewr
}
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
-return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
+return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType);
}
template <typename PoolOp>
@@ -197,12 +180,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
const int64_t inputWidth = xType.getDimSize(3);
const int64_t outputHeight = outType.getDimSize(2);
const int64_t outputWidth = outType.getDimSize(3);
-const int64_t kernelHeight = getI64(kernelAttr, 0);
-const int64_t kernelWidth = getI64(kernelAttr, 1);
-const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
-const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
-const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
-const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
+const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
+const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
+const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
+const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
+const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
+const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
int64_t padTop = 0;
int64_t padLeft = 0;
@@ -212,10 +195,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
if (auto padsAttr = poolOp.getPads()) {
if (padsAttr->size() != 4)
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
-padTop = getI64(*padsAttr, 0);
-padLeft = getI64(*padsAttr, 1);
-padBottom = getI64(*padsAttr, 2);
-padRight = getI64(*padsAttr, 3);
+padTop = getI64Attr(*padsAttr, 0);
+padLeft = getI64Attr(*padsAttr, 1);
+padBottom = getI64Attr(*padsAttr, 2);
+padRight = getI64Attr(*padsAttr, 3);
}
else {
StringRef autoPad = poolOp.getAutoPad();
@@ -283,13 +266,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
-Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
-Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
-Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
-Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
-Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
-Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
+Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
+Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
+Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount);
+Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth);
+Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth);
+Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
+Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
rewriter.setInsertionPointToStart(outputLoop.getBody());
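Reviewer note: this hunk swaps per-call `arith::ConstantIndexOp::create` for `getOrCreateIndexConstant(rewriter, anchorOp, value)`. The helper's implementation is not part of this diff, so the following is only a guess at the general idea (one materialized constant per anchor and value), written as a standalone C++ sketch. `ConstantPoolSketch` and its members are hypothetical names, and integer ids stand in for operations and values.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// Hypothetical model of constant uniquing: each (anchor, value) pair is
// materialized once and later requests reuse the same handle.
struct ConstantPoolSketch {
  std::vector<int64_t> materialized;                 // one entry per created constant
  std::map<std::pair<int, int64_t>, size_t> cache;   // (anchor id, value) -> handle

  size_t getOrCreate(int anchorId, int64_t value) {
    auto key = std::make_pair(anchorId, value);
    auto it = cache.find(key);
    if (it != cache.end())
      return it->second;                             // reuse the existing constant
    materialized.push_back(value);                   // "create" a new constant
    size_t handle = materialized.size() - 1;
    cache.emplace(key, handle);
    return handle;
  }
};
```

Under that assumption, repeated requests for `0` and `1` across nested loop builders would stop producing duplicate constant ops.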
@@ -314,14 +298,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
Value paddedInH = windowBaseH;
if (kernelH * dilationHeight != 0) {
-Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
+Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
}
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
Value paddedInW = windowBaseW;
if (kernelW * dilationWidth != 0) {
-Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
+Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
}
@@ -335,7 +319,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value windowValue =
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
-windowValue = materializeContiguousTile(rewriter, loc, windowValue);
+windowValue = materializeTileTensor(rewriter, loc, windowValue);
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
}
}
@@ -351,7 +335,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
Value scaleSlice = tensor::ExtractSliceOp::create(
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
-scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
+scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice);
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
}
@@ -4,7 +4,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -13,16 +13,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
-static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
-static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
-SmallVector<int64_t> permutedShape;
-permutedShape.reserve(permutation.size());
-for (int64_t axis : permutation)
-permutedShape.push_back(shape[axis]);
-return permutedShape;
-}
static Value buildLoopSoftmaxSlice(Value input,
Value accumulator,
RankedTensorType inputType,
@@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input,
SmallVector<OpFoldResult> offsets;
SmallVector<OpFoldResult> sizes;
-SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
offsets.reserve(rank);
sizes.reserve(rank);
@@ -62,9 +52,10 @@ static Value buildLoopSoftmaxNest(Value input,
if (axis == inputType.getRank() - 1)
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
-Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
-Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
+Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
+Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
+Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
rewriter.setInsertionPointToStart(loop.getBody());
@@ -110,44 +101,29 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
if (!inputType || !inputType.hasStaticShape())
return failure();
-int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank());
-if (axis < 0 || axis >= inputType.getRank())
+auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
+if (failed(axis))
return failure();
Value input = adaptor.getInput();
Value result;
-if (axis == inputType.getRank() - 1) {
+if (*axis == inputType.getRank() - 1) {
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
}
else {
SmallVector<int64_t> permutation;
permutation.reserve(inputType.getRank());
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
-if (dim != axis)
+if (dim != *axis)
permutation.push_back(dim);
-permutation.push_back(axis);
-SmallVector<int64_t> inversePermutation(inputType.getRank());
-for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
-inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
+permutation.push_back(*axis);
+SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
auto transposedType = RankedTensorType::get(
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
-auto preTransposeCompute =
-createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
-Value transposed = ONNXTransposeOp::create(
-rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
-spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
-});
-Value transposedInput = preTransposeCompute.getResult(0);
+Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
-auto postTransposeCompute =
-createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
-Value transposed = ONNXTransposeOp::create(
-rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
-spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
-});
-result = postTransposeCompute.getResult(0);
+result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
}
rewriter.replaceOp(softmaxOp, result);
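Reviewer note: the softmax hunk above replaces inline axis normalization and permutation inversion with `normalizeAxisChecked` and `invertPermutation`. Their definitions are not shown in this diff; the sketch below reproduces only the logic visible in the removed lines, using standard-library types instead of `FailureOr` and `SmallVector`. All three function names are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Wraps a negative axis and rejects anything outside [0, rank),
// combining the old normalizeAxis call with its follow-up range check.
inline std::optional<int64_t> normalizeAxisCheckedSketch(int64_t axis, int64_t rank) {
  int64_t normalized = axis >= 0 ? axis : rank + axis;
  if (normalized < 0 || normalized >= rank)
    return std::nullopt;
  return normalized;
}

// Builds the permutation that moves `axis` to the last position and keeps
// the remaining dimensions in their original order.
inline std::vector<int64_t> moveAxisToBackSketch(int64_t axis, int64_t rank) {
  std::vector<int64_t> permutation;
  permutation.reserve(static_cast<size_t>(rank));
  for (int64_t dim = 0; dim < rank; ++dim)
    if (dim != axis)
      permutation.push_back(dim);
  permutation.push_back(axis);
  return permutation;
}

// Inverse permutation: inverse[permutation[i]] = i.
inline std::vector<int64_t> invertPermutationSketch(const std::vector<int64_t>& permutation) {
  std::vector<int64_t> inverse(permutation.size());
  for (size_t newIndex = 0; newIndex < permutation.size(); ++newIndex)
    inverse[static_cast<size_t>(permutation[newIndex])] = static_cast<int64_t>(newIndex);
  return inverse;
}
```

Transposing with the permutation, running softmax over the last axis, and then transposing with the inverse returns the data to its original layout, which is why the pattern needs both.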
@@ -0,0 +1,288 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
return arg && canPromoteInputBlockArgument(*arg);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
struct PromotedOperands {
SmallVector<bool> promoteInput;
SmallVector<Value> newWeights;
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
};
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
continue;
return true;
}
return false;
}
template <typename ComputeOpTy>
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
PromotedOperands promoted;
promoted.promoteInput.assign(compute.getInputs().size(), false);
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
promoted.newInputs.reserve(compute.getInputs().size());
promoted.newInputTypes.reserve(compute.getInputs().size());
promoted.newInputLocs.reserve(compute.getInputs().size());
bool needsRewrite = false;
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (!isWeightLikeComputeOperand(input))
goto keep_input;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
goto keep_input;
promoted.promoteInput[inputIdx] = true;
promoted.newWeights.push_back(input);
needsRewrite = true;
continue;
keep_input:
promoted.newInputs.push_back(input);
promoted.newInputTypes.push_back(input.getType());
promoted.newInputLocs.push_back(input.getLoc());
}
if (!needsRewrite)
return failure();
return promoted;
}
template <typename ComputeOpTy>
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
const PromotedOperands& promoted,
IRRewriter& bodyRewriter,
IRMapping& mapper,
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
PatternRewriter& rewriter) {
size_t newInputIdx = 0;
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
auto oldArg = compute.getInputArgument(oldInputIdx);
if (!oldArg)
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
if (!promoted.promoteInput[oldInputIdx]) {
auto newInputArg = getNewInputArg(newInputIdx++);
if (!newInputArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
mapper.map(*oldArg, *newInputArg);
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(*oldArg, *clonedValue);
}
return success();
}
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto newCompute = spatial::SpatCompute::create(
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
if (failed(mapPromotedInputArguments(
compute,
*promoted,
bodyRewriter,
mapper,
[&](size_t index) { return newCompute.getInputArgument(index); },
rewriter)))
return failure();
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
auto promoted = computePromotedOperands(compute);
if (failed(promoted))
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
Block& oldBlock = compute.getBody().front();
rewriter.setInsertionPointAfter(compute);
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
promoted->newWeights,
promoted->newInputs);
auto laneArg = compute.getLaneArgument();
if (!laneArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
SmallVector<Type> newBlockArgTypes;
SmallVector<Location> newBlockArgLocs;
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
+ compute.getNumResults());
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
newBlockArgTypes.push_back(laneArg->getType());
newBlockArgLocs.push_back(laneArg->getLoc());
for (Value weight : promoted->newWeights) {
newBlockArgTypes.push_back(weight.getType());
newBlockArgLocs.push_back(weight.getLoc());
}
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
newBlockArgTypes.push_back(resultType);
newBlockArgLocs.push_back(outputArg->getLoc());
}
auto* newBlock = rewriter.createBlock(
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
auto newLaneArg = newCompute.getLaneArgument();
if (!newLaneArg)
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
mapper.map(*laneArg, *newLaneArg);
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
auto oldWeightArg = compute.getWeightArgument(weightIndex);
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
if (!oldWeightArg || !newWeightArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
mapper.map(*oldWeightArg, *newWeightArg);
}
if (failed(mapPromotedInputArguments(
compute,
*promoted,
bodyRewriter,
mapper,
[&](size_t index) { return newCompute.getInputArgument(index); },
rewriter)))
return failure();
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
auto outputArg = compute.getOutputArgument(resultIndex);
if (!outputArg)
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
mapper.map(*outputArg,
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
}
for (Operation& op : oldBlock)
rewriter.clone(op, mapper);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
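Reviewer note: the core of `computePromotedOperands` in the new file above is a partition of the compute inputs into "promote to weight" and "keep as input". The sketch below restates that partition with plain structs so the rule can be checked in isolation. `OperandSketch`, `PromotedSketch` and `partitionOperandsSketch` are hypothetical, and the three boolean fields stand in for `isWeightLikeComputeOperand`, `isDirectConstantValue` and `canPromoteInputBlockArgument`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Stand-in for one compute input together with the results of the predicates.
struct OperandSketch {
  std::string name;
  bool weightLike;      // isWeightLikeComputeOperand(input)
  bool directConstant;  // isDirectConstantValue(input)
  bool argPromotable;   // canPromoteInputBlockArgument(block argument)
};

struct PromotedSketch {
  std::vector<bool> promoteInput;       // mask over the original inputs
  std::vector<std::string> newWeights;  // old weights followed by promoted inputs
  std::vector<std::string> newInputs;   // inputs that stay runtime inputs
};

// Returns std::nullopt when nothing would be promoted, matching the
// failure() result that stops the rewrite pattern from firing.
inline std::optional<PromotedSketch> partitionOperandsSketch(
    const std::vector<std::string>& weights, const std::vector<OperandSketch>& inputs) {
  PromotedSketch promoted;
  promoted.promoteInput.assign(inputs.size(), false);
  promoted.newWeights = weights;
  bool needsRewrite = false;
  for (size_t i = 0; i < inputs.size(); ++i) {
    const OperandSketch& input = inputs[i];
    bool keep = !input.weightLike || (input.directConstant && !input.argPromotable);
    if (keep) {
      promoted.newInputs.push_back(input.name);
      continue;
    }
    promoted.promoteInput[i] = true;
    promoted.newWeights.push_back(input.name);
    needsRewrite = true;
  }
  if (!needsRewrite)
    return std::nullopt;
  return promoted;
}
```

Folding the two `goto keep_input` branches into a single `keep` condition keeps the same behavior while making the rule easier to read.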
@@ -1,6 +1,5 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
using namespace mlir;
@@ -12,7 +11,7 @@ namespace {
} // namespace
-void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
+void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
patterns.add<onnxToArithConstant>(ctx);
patterns.add<convAddToConvWithBiasLeft>(ctx);
patterns.add<convAddToConvWithBiasRight>(ctx);
@@ -3,7 +3,7 @@
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
auto inputs = adaptor.getInputs();
int64_t axis = adaptor.getAxis();
-if (llvm::all_of(inputs, isHostFoldableValue)) {
+if (llvm::all_of(inputs, isCompileTimeComputable)) {
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
return success();
}
@@ -6,7 +6,7 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -15,24 +15,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
-static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
-static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
-static Value
-extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
-auto inputType = cast<RankedTensorType>(input.getType());
-SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
-SmallVector<OpFoldResult> sizes;
-SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
-sizes.reserve(inputType.getRank());
-for (int64_t dim : inputType.getShape())
-sizes.push_back(rewriter.getIndexAttr(dim));
-offsets[axis] = rewriter.getIndexAttr(offset);
-sizes[axis] = rewriter.getIndexAttr(1);
-return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
-}
static Value concatGatherSlices(Value data,
int64_t axis,
ArrayRef<int64_t> indices,
@@ -45,7 +27,7 @@ static Value concatGatherSlices(Value data,
int64_t normalizedIndex = normalizeIndex(index, axisDim);
if (normalizedIndex < 0 || normalizedIndex >= axisDim)
return {};
-slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc));
+slices.push_back(extractAxisSlice(rewriter, loc, data, axis, normalizedIndex, /*size=*/1));
}
if (slices.empty())
return {};
@@ -96,11 +78,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
return failure();
int64_t rank = dataType.getRank();
-int64_t axis = normalizeAxis(gatherOp.getAxis(), rank);
-if (axis < 0 || axis >= rank)
+auto axis = normalizeAxisChecked(gatherOp.getAxis(), rank);
+if (failed(axis))
return failure();
-int64_t axisDim = dataType.getShape()[axis];
+int64_t axisDim = dataType.getShape()[*axis];
if (axisDim <= 0)
return failure();
@@ -116,7 +98,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
[&](Value data) -> LogicalResult {
Value result;
if (indicesType.getRank() == 1) {
-result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
+result = concatGatherSlices(data, *axis, flatIndices, axisDim, rewriter, loc);
}
else if (indicesType.getRank() == 2) {
int64_t rowCount = indicesType.getShape()[0];
@@ -125,12 +107,13 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
rows.reserve(rowCount);
for (int64_t row = 0; row < rowCount; ++row) {
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
-Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc);
+Value gatheredRow =
+concatGatherSlices(data, *axis, rowIndices, axisDim, rewriter, loc);
if (!gatheredRow)
return failure();
-rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
+rows.push_back(addLeadingGatherDim(gatheredRow, *axis, rewriter, loc));
}
-result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
+result = createSpatConcat(rewriter, loc, /*axis=*/*axis, rows);
}
else {
return failure();
@@ -4,8 +4,8 @@
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -14,10 +14,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
-static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
-return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
-}
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
ArrayRef<int64_t> resultShape,
SmallVector<ReassociationIndices>& reassociation) {
@@ -106,7 +102,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
-if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
+if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType))
return failure();
if (sourceType == resultType) {
@@ -115,17 +111,9 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
}
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
-if (isHostFoldableValue(adaptor.getData())) {
-rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
-return success();
-}
-auto computeOp = createSpatCompute<1>(
-rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
-Value reshaped = buildReshape(data);
-spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
-});
-rewriter.replaceOp(reshapeOp, computeOp.getResults());
+Value reshaped =
+materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
+rewriter.replaceOp(reshapeOp, reshaped);
return success();
};
@@ -6,7 +6,7 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
-#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
+#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -17,9 +17,10 @@ namespace {
static Value buildNearestAsymmetricIndex(
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
-Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
-Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
-Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
+Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+Value cInputDim = getOrCreateIndexConstant(rewriter, anchorOp, inputDim);
+Value cOutputDim = getOrCreateIndexConstant(rewriter, anchorOp, outputDim);
+Value cInputDimLast = getOrCreateIndexConstant(rewriter, anchorOp, inputDim - 1);
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
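Reviewer note: `buildNearestAsymmetricIndex` emits `muli`, `divui` and `minui`, which amounts to `min(outputIndex * inputDim / outputDim, inputDim - 1)` on unsigned integers. The scalar sketch below mirrors that arithmetic for quick sanity checks. The function name is hypothetical, and it assumes both dimensions are positive, as the static-shape checks in the pattern suggest.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Scalar analogue of the emitted index computation for "nearest" resize with
// the "asymmetric" coordinate transformation.
// Precondition (assumed): inputDim > 0 and outputDim > 0.
inline uint64_t nearestAsymmetricIndexSketch(uint64_t outputIndex, uint64_t inputDim, uint64_t outputDim) {
  uint64_t scaledIndex = outputIndex * inputDim;   // arith.muli
  uint64_t inputIndex = scaledIndex / outputDim;   // arith.divui (truncating)
  return std::min(inputIndex, inputDim - 1);       // arith.minui clamp
}
```

For any `outputIndex < outputDim` the quotient is already below `inputDim`, so the clamp only matters as a guard against out-of-range indices.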
@@ -37,12 +38,13 @@ static Value buildNearestResizeLoop(Value input,
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
-Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
-Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
-Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
-Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
-Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
-Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
+Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
+Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
+Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
+Value cOutputN = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(0));
+Value cOutputC = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(1));
+Value cOutputH = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(2));
+Value cOutputW = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(3));
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
@@ -58,24 +60,21 @@ static Value buildNearestResizeLoop(Value input,
Value outputC = channelLoop.getInductionVar();
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
Value inputC =
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
rewriter.setInsertionPointToStart(heightLoop.getBody());
Value outputH = heightLoop.getInductionVar();
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
Value inputH =
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
rewriter.setInsertionPointToStart(widthLoop.getBody());
Value outputW = widthLoop.getInductionVar();
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
Value inputW =
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
Value inputSlice =
@@ -114,8 +113,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|| resizeOp.getNearestMode() != "floor")
return rewriter.notifyMatchFailure(
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
return rewriter.notifyMatchFailure(resizeOp,
"resize lowering currently supports only nearest + asymmetric + floor.");
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
@@ -2,8 +2,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -12,25 +12,6 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
static Value extractSliceAt(
Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
auto inputType = cast<RankedTensorType>(input.getType());
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> sizes;
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
sizes.reserve(inputType.getRank());
for (int64_t dim : inputType.getShape())
sizes.push_back(rewriter.getIndexAttr(dim));
offsets[axis] = rewriter.getIndexAttr(offset);
sizes[axis] = rewriter.getIndexAttr(size);
SmallVector<int64_t> resultShape(inputType.getShape());
resultShape[axis] = size;
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
}
struct Split : OpConversionPattern<ONNXSplitOp> {
using OpConversionPattern::OpConversionPattern;
@@ -41,8 +22,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
return failure();
int64_t rank = inputType.getRank();
int64_t axis = normalizeAxis(splitOp.getAxis(), rank);
if (axis < 0 || axis >= rank)
auto axis = normalizeAxisChecked(splitOp.getAxis(), rank);
if (failed(axis))
return failure();
SmallVector<Value> outputs;
@@ -58,12 +39,12 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
if (!resultType || !resultType.hasStaticShape())
return failure();
resultTypes.push_back(resultType);
sliceSizes.push_back(resultType.getShape()[axis]);
sliceSizes.push_back(resultType.getShape()[*axis]);
}
if (isHostFoldableValue(adaptor.getInput())) {
if (isCompileTimeComputable(adaptor.getInput())) {
for (int64_t sliceSize : sliceSizes) {
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
outputs.push_back(extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
offset += sliceSize;
}
rewriter.replaceOp(splitOp, outputs);
@@ -76,7 +57,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
runtimeOutputs.reserve(resultTypes.size());
int64_t runtimeOffset = 0;
for (int64_t sliceSize : sliceSizes) {
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
runtimeOutputs.push_back(
extractAxisSlice(rewriter, splitOp.getLoc(), input, *axis, runtimeOffset, sliceSize));
runtimeOffset += sliceSize;
}
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
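Review note: both the compile-time and the runtime branch of `Split` walk the slice sizes with a running offset along one normalized axis. The sketch below restates that bookkeeping in plain C++. It assumes `normalizeAxisChecked` combines the removed `normalizeAxis` with the removed range check; the names `normalizeAxisSketch` and `sliceRanges` are hypothetical and do not appear in the patch.

```cpp
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for normalizeAxisChecked: wraps negative axes and
// rejects anything outside [0, rank).
std::optional<std::int64_t> normalizeAxisSketch(std::int64_t axis, std::int64_t rank) {
  std::int64_t normalized = axis >= 0 ? axis : rank + axis;
  if (normalized < 0 || normalized >= rank)
    return std::nullopt;
  return normalized;
}

// Returns (offset, size) pairs along the split axis, one per output slice.
std::vector<std::pair<std::int64_t, std::int64_t>> sliceRanges(const std::vector<std::int64_t>& sliceSizes) {
  std::vector<std::pair<std::int64_t, std::int64_t>> ranges;
  ranges.reserve(sliceSizes.size());
  std::int64_t offset = 0;
  for (std::int64_t size : sliceSizes) {
    ranges.emplace_back(offset, size);
    offset += size; // the next slice starts where this one ends
  }
  return ranges;
}
```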
@@ -0,0 +1,119 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static Value createTransposeInit(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
SmallVector<OpFoldResult> sizes;
sizes.reserve(resultType.getRank());
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
if (!ShapedType::isDynamic(resultDim)) {
sizes.push_back(rewriter.getIndexAttr(resultDim));
continue;
}
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
}
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
}
static FailureOr<Value> materializeTransposedConstant(Value input,
RankedTensorType resultType,
ArrayRef<int64_t> permutation,
ConversionPatternRewriter& rewriter,
Location loc) {
auto denseAttr = getHostConstDenseElementsAttr(input);
if (!denseAttr)
return failure();
auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
if (!inputType || !inputType.hasStaticShape() || !resultType.hasStaticShape()
|| inputType.getRank() != resultType.getRank()
|| static_cast<int64_t>(permutation.size()) != inputType.getRank()) {
return failure();
}
if (denseAttr.isSplat())
return getOrCreateConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>()),
resultType);
SmallVector<Attribute> inputValues(denseAttr.getValues<Attribute>());
SmallVector<Attribute> resultValues(inputValues.size());
SmallVector<int64_t> inputStrides = computeRowMajorStrides(inputType.getShape());
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
SmallVector<int64_t> inputIndices(inputType.getRank(), 0);
for (auto [linearIndex, value] : llvm::enumerate(inputValues)) {
int64_t remaining = static_cast<int64_t>(linearIndex);
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
inputIndices[dim] = inputStrides.empty() ? 0 : remaining / inputStrides[dim];
remaining = inputStrides.empty() ? 0 : remaining % inputStrides[dim];
}
int64_t resultLinearIndex = 0;
for (int64_t dim = 0; dim < resultType.getRank(); ++dim)
resultLinearIndex += inputIndices[permutation[dim]] * resultStrides[dim];
resultValues[resultLinearIndex] = value;
}
return getOrCreateConstant(rewriter,
rewriter.getInsertionBlock()->getParentOp(),
DenseElementsAttr::get(resultType, resultValues),
resultType);
}
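Review note: the non-splat path of `materializeTransposedConstant` decomposes each input linear index into coordinates using row-major strides, then recomposes a result linear index through the permutation. The sketch below reproduces that index mapping on a flat `std::vector<int>` so it can be tested without MLIR. The names `rowMajorStrides` and `transposeFlat` are hypothetical, and `rowMajorStrides` is only assumed to match what `computeRowMajorStrides` returns.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: strides[i] is the product of shape[i+1..].
std::vector<std::int64_t> rowMajorStrides(const std::vector<std::int64_t>& shape) {
  std::vector<std::int64_t> strides(shape.size(), 1);
  for (std::size_t i = shape.size(); i-- > 1;)
    strides[i - 1] = strides[i] * shape[i];
  return strides;
}

// Transposes a row-major flat buffer; result dim d takes input dim permutation[d].
std::vector<int> transposeFlat(const std::vector<int>& input,
                               const std::vector<std::int64_t>& inputShape,
                               const std::vector<std::int64_t>& permutation) {
  assert(permutation.size() == inputShape.size());
  const std::size_t rank = inputShape.size();
  std::vector<std::int64_t> resultShape(rank);
  for (std::size_t d = 0; d < rank; ++d)
    resultShape[d] = inputShape[permutation[d]];

  const auto inputStrides = rowMajorStrides(inputShape);
  const auto resultStrides = rowMajorStrides(resultShape);
  std::vector<int> result(input.size());
  std::vector<std::int64_t> inputIndices(rank, 0);

  for (std::size_t linear = 0; linear < input.size(); ++linear) {
    // Decompose the input linear index into per-dimension coordinates.
    std::int64_t remaining = static_cast<std::int64_t>(linear);
    for (std::size_t d = 0; d < rank; ++d) {
      inputIndices[d] = remaining / inputStrides[d];
      remaining %= inputStrides[d];
    }
    // Recompose the result linear index through the permutation.
    std::int64_t resultLinear = 0;
    for (std::size_t d = 0; d < rank; ++d)
      resultLinear += inputIndices[permutation[d]] * resultStrides[d];
    result[resultLinear] = input[linear];
  }
  return result;
}
```

With a 2x3 input `{1, 2, 3, 4, 5, 6}` and permutation `{1, 0}`, the 3x2 result is `{1, 4, 2, 5, 3, 6}`.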
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
ONNXTransposeOpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
if (!inputType || !resultType)
return failure();
auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
if (failed(permutation))
return failure();
if (isCompileTimeComputable(adaptor.getData())) {
auto constantTranspose =
materializeTransposedConstant(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
if (succeeded(constantTranspose)) {
rewriter.replaceOp(transposeOp, *constantTranspose);
return success();
}
}
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
Value transposed =
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
rewriter.replaceOp(transposeOp, transposed);
return success();
}
};
} // namespace
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<TransposeToLinalgTranspose>(ctx);
}
} // namespace onnx_mlir
@@ -1,286 +0,0 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
static bool isWeightMaterializationHelperUser(Operation* op) {
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
}
static bool canPromoteInputBlockArgument(BlockArgument arg) {
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
}
static bool isDirectConstantValue(Value value) {
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
}
template <typename ComputeOpTy>
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
Block& block = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= block.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
continue;
return true;
}
return false;
}
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
if (batchOp.getLaneCount() != 1)
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
auto loc = batchOp.getLoc();
rewriter.setInsertionPoint(batchOp);
auto computeOp =
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
computeOp.getProperties().setOperandSegmentSizes(
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
Block& templateBlock = batchOp.getBody().front();
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
blockArgTypes.reserve(templateBlock.getNumArguments());
blockArgLocs.reserve(templateBlock.getNumArguments());
for (BlockArgument arg : templateBlock.getArguments()) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(loc);
}
auto* newBlock =
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
IRMapping mapper;
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
mapper.map(oldArg, newArg);
rewriter.setInsertionPointToEnd(newBlock);
for (Operation& op : templateBlock)
rewriter.clone(op, mapper);
batchOp->replaceAllUsesWith(computeOp->getResults());
rewriter.eraseOp(batchOp);
return success();
}
};
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
bool needsRewrite = false;
Block& oldBlock = compute.getBody().front();
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (inputIdx >= oldBlock.getNumArguments())
continue;
if (!isWeightLikeComputeOperand(input))
continue;
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
continue;
promoteInput[inputIdx] = true;
needsRewrite = true;
}
if (!needsRewrite)
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
rewriter.setInsertionPointAfter(compute);
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
SmallVector<Value> newInputs;
SmallVector<Type> newInputTypes;
SmallVector<Location> newInputLocs;
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
newInputs.reserve(compute.getInputs().size());
newInputTypes.reserve(compute.getInputs().size());
newInputLocs.reserve(compute.getInputs().size());
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
if (promoteInput[inputIdx]) {
newWeights.push_back(input);
continue;
}
newInputs.push_back(input);
newInputTypes.push_back(input.getType());
newInputLocs.push_back(input.getLoc());
}
auto newCompute =
spatial::SpatComputeBatch::create(rewriter,
compute.getLoc(),
compute.getResultTypes(),
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
newWeights,
newInputs);
auto* newBlock =
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
newCompute.getProperties().setOperandSegmentSizes(
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
rewriter.setInsertionPointToStart(newBlock);
IRRewriter bodyRewriter(rewriter.getContext());
bodyRewriter.setInsertionPointToStart(newBlock);
IRMapping mapper;
size_t newInputIdx = 0;
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
if (!promoteInput[oldInputIdx]) {
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
continue;
}
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
if (failed(clonedValue))
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
mapper.map(oldArg, *clonedValue);
}
for (Operation& op : oldBlock.without_terminator())
rewriter.clone(op, mapper);
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
SmallVector<Value> newYieldOperands;
newYieldOperands.reserve(oldYield.getOutputs().size());
for (Value operand : oldYield.getOutputs()) {
auto mapped = mapper.lookupOrNull(operand);
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
}
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
rewriter.replaceOp(compute, newCompute.getResults());
return success();
}
};
} // namespace
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
}
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
}
void annotateWeightsConstants(func::FuncOp funcOp) {
funcOp.walk([&](arith::ConstantOp constantOp) {
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
markWeightAlways(constantOp);
});
}
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
} // namespace onnx_mlir
@@ -1,22 +0,0 @@
#pragma once
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
bool requiresPostRewrite(spatial::SpatCompute computeOp);
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
} // namespace onnx_mlir
@@ -1,10 +0,0 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "mlir/Transforms/DialectConversion.h"
namespace onnx_mlir {
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
} // namespace onnx_mlir
@@ -1,11 +1,15 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
@@ -15,7 +19,11 @@ using namespace onnx_mlir::pim;
namespace onnx_mlir {
namespace {
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
return isExplicitDevToHostTargetOperand(use.getOwner(), use.getOperandNumber());
});
}
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
@@ -28,54 +36,73 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
return coreIds;
}
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
if (!result.hasOneUse())
return failure();
pim::PimSendTensorBatchOp::create(rewriter,
sendTensorBatchOp.getLoc(),
mapper.lookup(sendTensorBatchOp.getInput()),
rewriter.getDenseI32ArrayAttr(targetCoreIds));
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
if (!returnOp)
return failure();
return result.getUses().begin()->getOperandNumber();
}
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
IRMapping& mapper,
IRRewriter& rewriter) {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
if (scale == 1)
return base;
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
receiveTensorBatchOp.getLoc(),
outputBuffer.getType(),
outputBuffer,
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
.getOutput();
mapper.map(receiveTensorBatchOp.getOutput(), received);
auto scaleValue = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scale);
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
}
static Value createHostTargetOffset(IRRewriter& rewriter,
tensor::ParallelInsertSliceOp insertSlice,
ShapedType destinationType,
IRMapping& mapper) {
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
Value totalOffset;
Location loc = insertSlice.getLoc();
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
int64_t scale = strides[dim] * elementBytes;
Value scaledOffset;
if (auto attr = dyn_cast<Attribute>(offset)) {
auto intAttr = dyn_cast<IntegerAttr>(attr);
assert(intAttr && "expected integer offset attribute");
scaledOffset =
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
}
else {
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
}
totalOffset =
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
}
if (!totalOffset)
totalOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
return totalOffset;
}
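Review note: `createHostTargetOffset` sums, over every dimension, the slice offset times the row-major stride times the element size in bytes, and falls back to zero when there are no offsets. The sketch below performs the same arithmetic on constant offsets only; the dynamic-offset branch that emits `arith.muli` is not modeled. The name `hostByteOffset` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: byte offset of an element-space offset vector inside a
// row-major destination tensor with the given shape and element size.
std::int64_t hostByteOffset(const std::vector<std::int64_t>& shape,
                            const std::vector<std::int64_t>& offsets,
                            std::int64_t elementBytes) {
  assert(shape.size() == offsets.size());
  std::int64_t stride = 1; // row-major stride of the current dimension, in elements
  std::int64_t total = 0;
  for (std::size_t d = shape.size(); d-- > 0;) {
    total += offsets[d] * stride * elementBytes;
    stride *= shape[d];
  }
  return total; // stays 0 for a rank-0 destination, like the constant-0 fallback
}
```

For a 4x8 tensor of 4-byte elements, the offset `{2, 3}` lands at byte `(2 * 8 + 3) * 4 = 76`.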
} // namespace
LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
if (computeBatchOp.getNumResults() != 0)
return computeBatchOp.emitOpError(
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
IRRewriter& rewriter) {
Location loc = computeBatchOp.getLoc();
Block& oldBlock = computeBatchOp.getBody().front();
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
if (oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
if (computeBatchOp.getNumResults() == 0) {
if (!oldYield || oldYield.getNumOperands() != 0)
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
}
else if (!inParallelOp) {
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
}
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
SmallVector<Value> batchInputs;
if (!computeBatchOp.getInputs().empty())
@@ -91,9 +118,22 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
SmallVector<unsigned> returnOperandIndices;
if (computeBatchOp.getNumResults() != 0) {
returnOperandIndices.resize(computeBatchOp.getNumResults());
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
if (failed(returnOperandIndex))
return computeBatchOp.emitOpError(
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
returnOperandIndices[resultIndex] = *returnOperandIndex;
}
}
SmallVector<Type> blockArgTypes;
SmallVector<Location> blockArgLocs;
for (BlockArgument arg : oldBlock.getArguments()) {
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
blockArgTypes.push_back(arg.getType());
blockArgLocs.push_back(arg.getLoc());
}
@@ -102,38 +142,44 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
IRMapping mapper;
rewriter.setInsertionPointToStart(newBlock);
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
auto oldLaneArg = computeBatchOp.getLaneArgument();
if (!oldLaneArg)
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
if (!oldWeightArg)
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
}
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
if (!oldArg)
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
auto newArgType = cast<ShapedType>(newArg.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
loc,
outputBuffer.getType(),
zeroOffset,
zeroOffset,
outputBuffer,
newArg,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, newArg))
.getOutput();
mapper.map(oldArg, copied);
mapper.map(*oldArg, copied);
}
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
if (auto mapped = mapper.lookupOrNull(capturedTensor))
return mapped;
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
Value& hostOutputTensor = hostOutputTensors[resultIndex];
if (hostOutputTensor)
return hostOutputTensor;
auto capturedType = cast<ShapedType>(capturedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
capturedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, capturedTensor))
.getOutput();
mapper.map(capturedTensor, copied);
return copied;
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
return hostOutputTensor;
};
rewriter.setInsertionPointToEnd(newBlock);
@@ -141,36 +187,37 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa<spatial::SpatYieldOp>(op))
continue;
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
pim::PimSendBatchOp::create(rewriter,
loc,
mapper.lookup(sendBatchOp.getInput()),
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
sendBatchOp.getTargetCoreIdsAttr());
continue;
}
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
if (!firstOutputArg)
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
for (Operation& nestedOp : parallelOp.getRegion().front()) {
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
if (!insertSlice)
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
continue;
}
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
if (!outputArg || outputArg.getOwner() != &oldBlock)
return insertSlice.emitOpError("expected compute_batch output block argument destination");
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
auto received = pim::PimReceiveBatchOp::create(rewriter,
loc,
outputBuffer.getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
receiveBatchOp.getSourceCoreIdsAttr())
.getOutput();
mapper.map(receiveBatchOp.getOutput(), received);
continue;
}
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
if (resultIndex >= returnOperandIndices.size())
return insertSlice.emitOpError("result index out of range while lowering host batch output");
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
Value mappedSource = mapper.lookup(insertSlice.getSource());
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
pim::PimMemCopyDevToHostOp::create(rewriter,
insertSlice.getLoc(),
hostTarget.getType(),
hostTargetOffset,
zeroOffset,
hostTarget,
mappedSource,
getTensorSizeInBytesAttr(rewriter, mappedSource));
}
continue;
}
@@ -178,15 +225,20 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
Operation* cloned = rewriter.clone(op, mapper);
auto clonedTensor = cloned->getResult(0);
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
mapper.map(toTensorOp.getResult(), clonedTensor);
continue;
}
auto clonedType = cast<ShapedType>(clonedTensor.getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
loc,
outputBuffer.getType(),
zeroOffset,
zeroOffset,
outputBuffer,
clonedTensor,
rewriter.getI32IntegerAttr(0),
rewriter.getI32IntegerAttr(0),
getTensorSizeInBytesAttr(rewriter, clonedTensor))
.getOutput();
mapper.map(toTensorOp.getResult(), copied);
@@ -194,15 +246,18 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
}
}
for (Value operand : op.getOperands()) {
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
continue;
if (isExplicitDevToHostTargetOperand(&op, operandIndex))
continue;
Operation* definingOp = operand.getDefiningOp();
if (definingOp && definingOp->getBlock() == &oldBlock)
continue;
materializeCapturedTensor(operand);
return computeBatchOp.emitOpError(
"expected external tensor communication to be materialized in Spatial before batch lowering");
}
Operation* cloned = rewriter.clone(op, mapper);
@@ -1,10 +0,0 @@
#pragma once
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
namespace onnx_mlir {
mlir::LogicalResult
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -3,17 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
add_public_tablegen_target(SpatialToPimIncGen)
add_pim_library(OMSpatialToPim
Patterns.cpp
SpatialToPimPass.cpp
BatchCoreLoweringPatterns.cpp
ChannelLoweringPatterns.cpp
Cleanup.cpp
Common.cpp
ComputeLikeRegionUtils.cpp
CoreLoweringPatterns.cpp
GlobalTensorMaterialization.cpp
PhaseVerification.cpp
ReturnPathNormalization.cpp
TensorPackingPatterns.cpp
Patterns/ChannelLowering.cpp
Patterns/GlobalTensorMaterialization.cpp
Patterns/TensorPacking.cpp
Patterns/Transpose.cpp
EXCLUDE_FROM_OM_LIBS
@@ -21,7 +21,10 @@ add_pim_library(OMSpatialToPim
SpatialToPimIncGen
LINK_LIBS PUBLIC
MLIRLinalgDialect
MLIRSCFDialect
MLIRSCFUtils
MLIRTransformUtils
MLIRTosaDialect
OMCompilerOptions
OMPimCommon
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
} // namespace onnx_mlir
@@ -1,42 +0,0 @@
#include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
while (!pendingOps.empty()) {
bool erasedAnyOp = false;
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
Operation* opToRemove = *it;
if (!opToRemove->use_empty()) {
++it;
continue;
}
rewriter.eraseOp(opToRemove);
it = pendingOps.erase(it);
erasedAnyOp = true;
}
if (erasedAnyOp)
continue;
for (Operation* opToRemove : pendingOps) {
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
for (Operation* user : opToRemove->getUsers()) {
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
}
}
return failure();
}
return success();
}
} // namespace onnx_mlir
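The removed `erasePendingOps` helper above sweeps the pending list repeatedly, erasing whatever has no remaining users, and gives up once a full sweep makes no progress. A minimal standalone sketch of that fixed-point loop follows. `ToyOp`, `eraseToyOp`, and `erasePendingToyOps` are hypothetical stand-ins introduced only for illustration; they are not part of MLIR or of this repository, and the diagnostics emitted by the original are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for mlir::Operation: it only tracks how many live
// users still reference it and which nodes it consumes as operands.
struct ToyOp {
  int liveUsers = 0;
  std::vector<ToyOp*> operands;
  bool erased = false;
};

// Erasing a node releases one use on each of its operands.
inline void eraseToyOp(ToyOp* op) {
  assert(op->liveUsers == 0 && "cannot erase an op that still has users");
  for (ToyOp* operand : op->operands)
    --operand->liveUsers;
  op->erased = true;
}

// Returns true when every pending node could be erased. Returns false when a
// full sweep erases nothing, i.e. every remaining node still has users.
inline bool erasePendingToyOps(std::vector<ToyOp*>& pending) {
  while (!pending.empty()) {
    bool erasedAny = false;
    for (auto it = pending.begin(); it != pending.end();) {
      if ((*it)->liveUsers != 0) {
        ++it; // still used; retry on the next sweep
        continue;
      }
      eraseToyOp(*it);
      it = pending.erase(it);
      erasedAny = true;
    }
    if (!erasedAny)
      return false; // no progress: report instead of looping forever
  }
  return true;
}
```

With a producer listed before its only consumer, the first sweep skips the producer and erases the consumer, and the second sweep erases the producer. A node whose user is not in the pending list never becomes erasable, so the function returns false and leaves it in the list.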
@@ -1,11 +0,0 @@
#pragma once
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
namespace onnx_mlir {
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,10 +1,8 @@
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include <cassert>
#include <cstddef>
#include "Common.hpp"
@@ -13,52 +11,6 @@ using namespace mlir;
namespace onnx_mlir {
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
/*
EXAMPLE RUN:
[1, 10, 3, 4] inputShape
[0, 2, 1, 3] offsets
acc = 1
---
ret = 3
acc = 4
---
ret = 3 + 4 * 1 = 7
acc = 12
---
ret = 7 + 12 * 2 = 31
acc = 120
---
ret = 31 + 120 * 0 = 31
acc = 120
*/
size_t returnValue = 0;
auto sliceOffsets = sliceOp.getStaticOffsets();
auto inputDimSizes = inputShape.getShape();
assert(sliceOffsets.size() == inputDimSizes.size());
size_t accumulatedDimensionSize = 1;
// Reverse iterate the two vectors
for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) {
auto curSliceOffset = std::get<0>(it);
auto curInputDimSize = std::get<1>(it);
returnValue += accumulatedDimensionSize * curSliceOffset;
accumulatedDimensionSize *= curInputDimSize;
}
return returnValue;
}
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
}
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
}
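The arithmetic in the removed `getSliceActualOffset` and `getShapedTypeSizeInBytes` can be checked without MLIR. The sketch below restates both on plain vectors and integers. `linearSliceOffset` and `sizeInBytes` are hypothetical names chosen for this example, not functions from the repository; like the original, the offset computation assumes a row-major layout and fully static offsets.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major linear offset (in elements) of a slice that starts at `offsets`
// inside a tensor of shape `shape`. Walks the dimensions from innermost to
// outermost, accumulating the stride as it goes.
inline std::size_t linearSliceOffset(const std::vector<std::int64_t>& offsets,
                                     const std::vector<std::int64_t>& shape) {
  assert(offsets.size() == shape.size() && "offsets and shape must have the same rank");
  std::size_t result = 0;
  std::size_t stride = 1;
  for (std::size_t i = shape.size(); i-- > 0;) {
    result += stride * static_cast<std::size_t>(offsets[i]);
    stride *= static_cast<std::size_t>(shape[i]);
  }
  return result;
}

// Same formula as the removed helper: element count times bit width, divided
// by 8. Integer division truncates, so a total bit count that is not a
// multiple of 8 loses its partial byte.
inline std::size_t sizeInBytes(std::size_t numElements, std::size_t elementBitWidth) {
  return numElements * elementBitWidth / 8;
}
```

For the shape `[1, 10, 3, 4]` and offsets `[0, 2, 1, 3]` used in the comment's example run, `linearSliceOffset` yields 31, and a 120-element tensor of 32-bit values occupies 480 bytes.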
@@ -6,22 +6,6 @@
namespace onnx_mlir {
/**
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
* its static tensor input.
*
* The static offsets represent the starting position of the slice in each
* dimension, while the static tensor input gives its dimension size.
*
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
* calculated.
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
* \return The actual offset of the ExtractSliceOp.
*/
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
template <class T>
@@ -1,3 +1,5 @@
#include <cassert>
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -29,7 +31,18 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
unsigned inputIndex,
Value replacement) {
Block& body = owner->getRegion(0).front();
BlockArgument bodyArgument = body.getArgument(inputIndex);
BlockArgument bodyArgument;
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
auto computeArg = compute.getInputArgument(inputIndex);
assert(computeArg && "expected compute input block argument");
bodyArgument = *computeArg;
}
else {
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
assert(batchArg && "expected compute_batch input block argument");
bodyArgument = *batchArg;
}
unsigned bodyArgIndex = bodyArgument.getArgNumber();
rewriter.startOpModification(owner);
bodyArgument.replaceAllUsesWith(replacement);
@@ -37,7 +50,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
compute.getInputsMutable().erase(inputIndex);
else
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
body.eraseArgument(inputIndex);
body.eraseArgument(bodyArgIndex);
rewriter.finalizeOpModification(owner);
}
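The change to `replaceAndEraseDirectComputeLikeInput` stops treating the input index as the block argument index and instead erases the argument at `bodyArgument.getArgNumber()`. A plausible reading, based on the later call `block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size())`, is that weight arguments precede input arguments in the region body, so the two indices differ whenever weights are present. The sketch below illustrates that off-by-`numWeights` mapping under this assumption. `ToyComputeRegion` and its members are hypothetical and do not mirror the real `SpatCompute` API.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical model of a compute-like region body. Assumption: block
// arguments are laid out as all weights first, followed by all inputs.
struct ToyComputeRegion {
  std::size_t numWeights = 0;
  std::vector<std::string> blockArgs;

  // Position of the block argument that backs input number `inputIndex`,
  // or std::nullopt when no such argument exists.
  std::optional<std::size_t> inputArgPosition(std::size_t inputIndex) const {
    std::size_t position = numWeights + inputIndex;
    if (position >= blockArgs.size())
      return std::nullopt;
    return position;
  }

  // Erases the argument backing the input. Using `inputIndex` directly as
  // the position would remove a weight argument whenever numWeights > 0.
  bool eraseInput(std::size_t inputIndex) {
    std::optional<std::size_t> position = inputArgPosition(inputIndex);
    if (!position)
      return false;
    blockArgs.erase(blockArgs.begin() + static_cast<std::ptrdiff_t>(*position));
    return true;
  }
};
```

With two weights and two inputs, erasing input 0 removes the third block argument and leaves both weight arguments untouched, which is the behaviour the `bodyArgIndex` change appears to be aiming for.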
@@ -1,13 +1,15 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -23,11 +25,12 @@ static bool isChannelUseChainOp(Operation* op) {
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
linalg::TransposeOp,
pim::PimTransposeOp>(op);
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
@@ -36,7 +39,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -46,8 +54,6 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
}
}
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
@@ -92,7 +98,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
return success();
}
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
return false;
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
@@ -101,7 +109,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
return false;
Block& block = computeOp.getBody().front();
if (block.getNumArguments() != 0)
if (block.getNumArguments() != computeOp.getWeights().size())
return false;
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
@@ -110,8 +118,14 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
rewriter.setInsertionPoint(computeOp);
IRMapping mapping;
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
auto weightArg = computeOp.getWeightArgument(weightIndex);
if (!weightArg)
return false;
mapping.map(*weightArg, weight);
}
for (Operation& op : block.without_terminator()) {
cloneMappedHelperOperands(&op, mapping, rewriter);
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -125,15 +139,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
} // namespace
void markOpToRemove(CoreLoweringState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
IRRewriter& rewriter,
OperationFolder& constantFolder) {
Location loc = computeOp->getLoc();
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
return success();
SmallVector<Operation*> helperChain;
@@ -143,21 +154,24 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
auto& block = computeOp.getRegion().front();
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
if (!receiveOp || blockArg.use_empty())
continue;
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
auto outputType = cast<ShapedType>(blockArg.getType());
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
auto blockArg = computeOp.getInputArgument(inputIndex);
if (!blockArg)
return computeOp.emitOpError("expected compute input block arguments during lowering");
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
if (receiveOp && !blockArg->use_empty()) {
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
auto outputType = cast<ShapedType>(blockArg->getType());
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
Value received = PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
Value received =
PimReceiveOp::create(
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
.getOutput();
blockArg.replaceAllUsesWith(received);
markOpToRemove(state, receiveOp);
blockArg->replaceAllUsesWith(received);
markOpToRemove(receiveOp);
continue;
}
}
if (computeOp.getNumResults() != yieldOp.getNumOperands())
@@ -167,9 +181,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
if (result.use_empty())
continue;
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
ReturnPathLoweringResult returnPathResult =
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
if (returnPathResult == ReturnPathLoweringResult::Failure)
return failure();
if (returnPathResult == ReturnPathLoweringResult::Handled)
@@ -193,15 +206,40 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
if (!computeOp.getWeights().empty())
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
rewriter.setInsertionPointAfter(computeOp);
auto coreOp = PimCoreOp::create(rewriter,
loc,
ValueRange(computeWeights),
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
auto coreOp = PimCoreOp::create(
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
rewriter.setInsertionPointToStart(&block);
auto& coreOpBlocks = coreOp.getBody().getBlocks();
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
if (!blockArg.use_empty())
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
block.eraseArguments(0, block.getNumArguments());
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
auto blockArg = computeOp.getInputArgument(inputIndex);
if (!blockArg)
return computeOp.emitOpError("expected compute input block arguments during input materialization");
if (blockArg->use_empty())
continue;
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
blockArg->replaceAllUsesWith(getOrCreateConstantLike(constantFolder, constantOp));
continue;
}
auto inputType = dyn_cast<ShapedType>(input.getType());
if (!inputType)
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
auto copied =
PimMemCopyHostToDevOp::create(rewriter,
loc,
outputBuffer.getType(),
getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
outputBuffer,
input,
getTensorSizeInBytesAttr(rewriter, input))
.getOutput();
blockArg->replaceAllUsesWith(copied);
}
if (!computeOp.getInputs().empty())
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
Block* tempComputeBlock = new Block();
computeOp.getBody().push_back(tempComputeBlock);
@@ -1,21 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
namespace onnx_mlir {
struct CoreLoweringState {
size_t& nextCoreId;
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
};
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
mlir::LogicalResult
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/PatternMatch.h"
namespace onnx_mlir {
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
}
@@ -0,0 +1,24 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace raptor {
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
} // namespace raptor
void populateInitialPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
raptor::populateWithGenerated(patterns);
populateTransposeLoweringPatterns(patterns);
}
} // namespace onnx_mlir
@@ -8,6 +8,14 @@
namespace onnx_mlir {
void populateInitialPatterns(mlir::RewritePatternSet& patterns);
void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns);
void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
mlir::Value extractPackedChunk(mlir::Value packedValue,
mlir::RankedTensorType chunkType,
@@ -20,7 +28,6 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
mlir::OpBuilder& builder,
mlir::Location loc);
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
} // namespace onnx_mlir
@@ -1,7 +1,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -10,17 +10,12 @@ using namespace mlir;
namespace onnx_mlir {
namespace {
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
pim::PimSendOp::create(rewriter,
op.getLoc(),
op.getInput(),
getTensorSizeInBytesAttr(rewriter, op.getInput()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
pim::PimSendOp::create(
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
rewriter.eraseOp(op);
return success();
}
@@ -42,41 +37,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
op.getResult().getType(),
outputBuffer,
getTensorSizeInBytesAttr(rewriter, op.getResult()),
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
.getOutput();
rewriter.replaceOp(op, received);
return success();
}
};
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> targetCoreIds;
targetCoreIds.reserve(op.getTargetCoreIds().size());
for (int32_t targetCoreId : op.getTargetCoreIds())
targetCoreIds.push_back(toPimCoreId(targetCoreId));
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
rewriter.eraseOp(op);
return success();
}
};
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
SmallVector<int32_t> sourceCoreIds;
sourceCoreIds.reserve(op.getSourceCoreIds().size());
for (int32_t sourceCoreId : op.getSourceCoreIds())
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
auto outputType = cast<ShapedType>(op.getOutput().getType());
Value outputBuffer =
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
Value received =
pim::PimReceiveTensorOp::create(
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
op.getSourceCoreId())
.getOutput();
rewriter.replaceOp(op, received);
return success();
@@ -125,12 +86,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
} // namespace
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<ChannelSendLowering,
ChannelReceiveLowering,
ChannelSendTensorLowering,
ChannelReceiveTensorLowering,
ExtractRowsLowering,
ConcatLowering>(patterns.getContext());
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -16,7 +16,7 @@
#include "Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
@@ -76,10 +76,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
if (!BBArgValue)
return failure();
if (BBArgValue.use_empty())
if (BBArgValue->use_empty())
continue;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
@@ -89,16 +90,17 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
if (!BBArgValue)
return failure();
if (BBArgValue.use_empty())
if (BBArgValue->use_empty())
continue;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
@@ -108,7 +110,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
}
else {
{
@@ -143,170 +145,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
}
};
// Turns runtime constants consumed by compute regions into private globals and local loads.
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
Location loc = constantOp.getLoc();
if (hasWeightAlways(constantOp))
return failure();
if (!isa<func::FuncOp>(constantOp->getParentOp()))
return failure();
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
if (isa<spatial::SpatCompute>(op))
return false;
if (isa<func::FuncOp>(op->getParentOp()))
return true;
return false;
}))
return failure();
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
if (constRankedTensorType) {
mlir::MemRefType memRefType =
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
loc,
constantOp->getParentOfType<ModuleOp>(),
"const",
memRefType,
constantOp.getValueAttr(),
rewriter.getUnitAttr());
std::string argName = globalOp.getSymName().str();
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
replaceAndEraseDirectComputeLikeInput(rewriter,
spatComputeBatch.getOperation(),
BBArgIndex,
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
}
else {
{
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatCompute.getOperation());
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
rewriter.finalizeOpModification(spatCompute.getOperation());
}
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
auto toTensor = bufferization::ToTensorOp::create(
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
}
rewriter.startOpModification(spatComputeBatch.getOperation());
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
}
}
}
}
}
else if (constantOp.getType().isIntOrIndexOrFloat()) {
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
auto constUsers = constUses.getOwner();
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
if (!inputIndex)
return failure();
auto BBArgIndex = *inputIndex;
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
replaceAndEraseDirectComputeLikeInput(
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
}
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
if (!mapSpatComputeToConst.contains(parent)) {
rewriter.setInsertionPoint(&parent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
}
else {
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
assert(batchParent && "Global Constant used directly not within a compute");
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
auto newConst = rewriter.clone(*constantOp);
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
}
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
}
}
}
if (constantOp->use_empty())
rewriter.eraseOp(constantOp);
return success();
}
};
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
using OpRewritePattern::OpRewritePattern;
@@ -383,8 +221,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
} // namespace
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
patterns.getContext());
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,4 +1,4 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
@@ -0,0 +1,38 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir;
namespace onnx_mlir {
namespace {
struct LinalgTransposeToPim final : OpRewritePattern<linalg::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter& rewriter) const override {
SmallVector<Attribute> permutationAttrs;
permutationAttrs.reserve(transposeOp.getPermutation().size());
for (int64_t dim : transposeOp.getPermutation())
permutationAttrs.push_back(rewriter.getI64IntegerAttr(dim));
auto permutation = rewriter.getArrayAttr(permutationAttrs);
auto pimTranspose = pim::PimTransposeOp::create(rewriter,
transposeOp.getLoc(),
TypeRange {transposeOp->getResult(0).getType()},
transposeOp.getInput(),
permutation,
transposeOp.getInit());
rewriter.replaceOp(transposeOp, pimTranspose.getOutput());
return success();
}
};
} // namespace
void populateTransposeLoweringPatterns(RewritePatternSet& patterns) {
patterns.add<LinalgTransposeToPim>(patterns.getContext());
}
} // namespace onnx_mlir
@@ -1,20 +0,0 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
using namespace mlir;
namespace onnx_mlir {
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
bool hasFailure = false;
moduleOp.walk([&](Operation* op) {
if (op->getDialect()->getNamespace() != "spat")
return;
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
hasFailure = true;
});
return success(!hasFailure);
}
} // namespace onnx_mlir
@@ -1,9 +0,0 @@
#pragma once
#include "mlir/IR/BuiltinOps.h"
namespace onnx_mlir {
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
} // namespace onnx_mlir
@@ -1,15 +1,18 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/FoldUtils.h"
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -38,15 +41,10 @@ static bool isReturnHelperChainOp(Operation* op) {
tensor::ExpandShapeOp,
tensor::CastOp,
tosa::ReshapeOp,
ONNXTransposeOp,
linalg::TransposeOp,
pim::PimTransposeOp>(op);
}
static void markOpToRemove(ReturnPathState& state, Operation* op) {
if (!llvm::is_contained(state.operationsToRemove, op))
state.operationsToRemove.push_back(op);
}
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
std::string name = baseName.str();
unsigned suffix = 0;
@@ -279,11 +277,10 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
continue;
}
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op)) {
SmallVector<int64_t> nextIndices(currentIndices.size());
SmallVector<int64_t> nextShape(currentShape.size());
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
int64_t sourceIndex = attr.getInt();
for (auto [destIndex, sourceIndex] : llvm::enumerate(transposeOp.getPermutation())) {
nextIndices[destIndex] = currentIndices[sourceIndex];
nextShape[destIndex] = currentShape[sourceIndex];
}
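Review note on the hunk above: the new loop reads the permutation straight from `linalg::TransposeOp::getPermutation()` and keeps the convention `next[dest] = current[permutation[dest]]` for both indices and shape. A minimal standalone sketch of that convention follows; `permute` is a hypothetical helper written for illustration and is not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not from the patch): applies a transpose permutation
// using the convention result[dest] = input[permutation[dest]].
std::vector<int64_t> permute(const std::vector<int64_t>& input,
                             const std::vector<int64_t>& permutation) {
  assert(input.size() == permutation.size() && "rank mismatch");
  std::vector<int64_t> result(input.size());
  for (std::size_t dest = 0; dest < permutation.size(); ++dest) {
    int64_t source = permutation[dest];
    assert(source >= 0 && static_cast<std::size_t>(source) < input.size());
    result[dest] = input[static_cast<std::size_t>(source)];
  }
  return result;
}
```

With shape `{2, 3, 4}` and permutation `{2, 0, 1}`, the permuted shape is `{4, 2, 3}`; applying the same call to an index vector moves the index along with its dimension.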
@@ -318,7 +315,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
return success();
}
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
static void
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
for (Value operand : op->getOperands()) {
if (mapping.lookupOrNull(operand))
continue;
@@ -327,7 +325,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
if (!definingOp)
continue;
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp));
continue;
}
if (!isa<tensor::EmptyOp>(definingOp))
continue;
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
@@ -337,15 +340,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
}
}
static void
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
static void cloneHelperChain(Value sourceValue,
ArrayRef<Operation*> helperChain,
IRRewriter& rewriter,
OperationFolder& constantFolder,
Value& clonedValue) {
IRMapping mapping;
mapping.map(sourceValue, sourceValue);
clonedValue = sourceValue;
rewriter.setInsertionPointAfterValue(sourceValue);
for (Operation* op : helperChain) {
cloneMappedHelperOperands(op, mapping, rewriter);
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
Operation* clonedOp = rewriter.clone(*op, mapping);
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
mapping.map(originalResult, newResult);
@@ -360,23 +366,26 @@ static Value emitHostCopy(IRRewriter& rewriter,
Value sourceValue,
int32_t hostTargetOffset,
int32_t deviceSourceOffset,
int32_t sizeInBytes) {
int32_t sizeInBytes,
OperationFolder& constantFolder) {
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset);
Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
return PimMemCopyDevToHostOp::create(rewriter,
loc,
outputTensor.getType(),
hostTargetOffsetValue,
deviceSourceOffsetValue,
outputTensor,
sourceValue,
rewriter.getI32IntegerAttr(hostTargetOffset),
rewriter.getI32IntegerAttr(deviceSourceOffset),
rewriter.getI32IntegerAttr(sizeInBytes))
.getOutput();
}
} // namespace
void addReturnOutputBuffers(func::ReturnOp returnOp,
IRRewriter& rewriter,
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
outputTensors.reserve(returnOp->getNumOperands());
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
Value currentReturnValue = returnValue;
@@ -411,70 +420,85 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
}
}
ReturnPathLoweringResult lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
Location loc = computeOp->getLoc();
auto yieldType = cast<TensorType>(yieldValue.getType());
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
Location loc = producerOp->getLoc();
OperationFolder constantFolder(producerOp->getContext());
auto storedTensorType = cast<TensorType>(storedValue.getType());
if (auto returnUse = analyzeReturnUse(result)) {
Value storedValue = yieldValue;
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
if (auto returnUse = analyzeReturnUse(producedValue)) {
Value currentStoredValue = storedValue;
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
for (Operation* op : returnUse->helperChain)
markOpToRemove(state, op);
markOpToRemove(op);
auto storedType = cast<ShapedType>(storedValue.getType());
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
if (auto storedOp = storedValue.getDefiningOp())
auto storedType = cast<ShapedType>(currentStoredValue.getType());
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
if (auto storedOp = currentStoredValue.getDefiningOp())
rewriter.setInsertionPointAfter(storedOp);
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
emitHostCopy(rewriter,
loc,
outputTensor,
currentStoredValue,
0,
0,
static_cast<int32_t>(storedType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
auto resultUses = result.getUses();
auto resultUses = producedValue.getUses();
if (rangeLength(resultUses) == 1) {
OpOperand& resultUse = *resultUses.begin();
Operation* resultUser = resultUse.getOwner();
if (isa<func::ReturnOp>(resultUser)) {
size_t resultIndexInReturn = resultUse.getOperandNumber();
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
emitHostCopy(rewriter,
loc,
outputTensor,
storedValue,
0,
0,
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
}
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
for (Operation* concatOp : concatReturnUse->concatChain)
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
if (concatReturnUse->helperChain.empty()) {
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
emitHostCopy(rewriter,
loc,
outputTensor,
yieldValue,
storedValue,
static_cast<int32_t>(flatOffset * elementSize),
0,
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
constantFolder);
return ReturnPathLoweringResult::Handled;
}
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
if (!storedType) {
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
producerOp->emitOpError(
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
rewriter.setInsertionPointAfterValue(yieldValue);
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
rewriter.setInsertionPointAfterValue(storedValue);
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
auto outputType = cast<ShapedType>(outputTensor.getType());
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
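Review note: the concat-return path relies on `computeFlatElementIndex` and `expandFlatElementIndex`, whose bodies are not visible in this diff. The sketch below shows how such a pair could behave, assuming a row-major (last dimension fastest) layout. The names `flattenIndex` and `expandIndex` are hypothetical stand-ins, and the layout is an assumption rather than something the patch confirms.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in: multi-dimensional index -> flat offset, row-major.
int64_t flattenIndex(const std::vector<int64_t>& indices,
                     const std::vector<int64_t>& shape) {
  assert(indices.size() == shape.size() && "rank mismatch");
  int64_t flat = 0;
  for (std::size_t d = 0; d < shape.size(); ++d) {
    assert(indices[d] >= 0 && indices[d] < shape[d] && "index out of bounds");
    flat = flat * shape[d] + indices[d];
  }
  return flat;
}

// Hypothetical stand-in: flat offset -> multi-dimensional index, row-major.
std::vector<int64_t> expandIndex(int64_t flat, const std::vector<int64_t>& shape) {
  std::vector<int64_t> indices(shape.size());
  for (std::size_t d = shape.size(); d-- > 0;) {
    indices[d] = flat % shape[d];
    flat /= shape[d];
  }
  return indices;
}
```

Under that assumption the two functions are inverses of each other for in-bounds offsets, which is the property the per-element loop needs: each source offset expands to indices, is remapped through the helper chain, and is flattened again against the output shape.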
@@ -484,7 +508,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
SmallVector<int64_t> destinationIndices;
if (failed(mapIndicesThroughHelperChain(
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
return ReturnPathLoweringResult::Failure;
}
@@ -503,7 +527,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
auto scalarTensorType =
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
auto elementSlice = tensor::ExtractSliceOp::create(
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
rewriter.setInsertionPointAfter(elementSlice);
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
@@ -513,7 +537,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
elementSlice.getResult(),
static_cast<int32_t>(destinationFlatOffset * elementSize),
0,
static_cast<int32_t>(elementSize));
static_cast<int32_t>(elementSize),
constantFolder);
}
return ReturnPathLoweringResult::Handled;
}
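Review note: these hunks replace `getElementTypeBitWidth() / 8` with `getElementTypeSizeInBytes(...)`. The definition of that helper is not shown here, so its exact behavior is unknown. One plausible reason for the change is that plain integer division truncates to zero for sub-byte element types. The sketch below only illustrates that arithmetic; all three functions are hypothetical and the rounding-up behavior is an assumption.

```cpp
#include <cassert>
#include <cstdint>

// Truncating conversion, as in the removed `bitWidth / 8` expressions.
constexpr int64_t truncatingBytes(int64_t bits) { return bits / 8; }

// Assumed alternative: round up so sub-byte types still occupy one byte.
constexpr int64_t roundedUpBytes(int64_t bits) { return (bits + 7) / 8; }

// Copy size for a tensor, mirroring `getNumElements() * elementSize`.
inline int64_t copySizeInBytes(int64_t numElements, int64_t bits) {
  assert(numElements >= 0 && bits > 0 && "expected a valid tensor description");
  return numElements * roundedUpBytes(bits);
}
```

For a 1-bit element type the truncating form yields a zero-byte copy, while the rounded-up form yields one byte per element; for byte-aligned widths such as 32 bits both agree.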
@@ -521,7 +546,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
return ReturnPathLoweringResult::NotReturnPath;
}
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
}
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
if (!op)
return;
@@ -538,13 +568,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
if (isReturnHelperChainOp(op)) {
Value source = op->getOperand(0);
markOpToRemove(state, op);
markOpToRemove(op);
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
markOpToRemove(state, computeOp);
markOpToRemove(computeOp);
if (!computeOp.getInputs().empty())
for (Value input : computeOp.getInputs())
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
@@ -552,23 +582,29 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
}
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
for (Value operand : concatOp.getOperands())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
markOpToRemove(state, concatOp);
markOpToRemove(concatOp);
for (Value operand : concatOp.getInputs())
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
return;
}
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
markOpToRemove(receiveOp);
return;
}
};
@@ -578,7 +614,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
size_t orderWithinReturn = it.index();
Operation* returnOperand = it.value().getDefiningOp();
rewriter.setInsertionPoint(returnOp);
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
}

Some files were not shown because too many files have changed in this diff