Compare commits
56 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b678e55d3c | |||
| ab63498f3f | |||
| 7c3943bd06 | |||
| c0238c0d06 | |||
| ff36729140 | |||
| cf93caecd5 | |||
| 2d5b03c08f | |||
| a41f694cf0 | |||
| 8bb0babf1b | |||
| 819d8af0f7 | |||
| 832bd7f1f7 | |||
| 82b44a6387 | |||
| 7fcc765d6e | |||
| f34698a2b6 | |||
| 1ab489fe0a | |||
| cbf7b235f1 | |||
| 00414dd1d9 | |||
| 783dffe553 | |||
| 874a2f53e6 | |||
| 4bdaa57656 | |||
| 1a5d7d2a3f | |||
| 013ae0ac2a | |||
| c6b02af7a9 | |||
| d2048bd394 | |||
| 158f0f0c54 | |||
| 532cac8246 | |||
| d609e84054 | |||
| addfc8a86e | |||
| 0f240af271 | |||
| bdc4ca33f3 | |||
| b79c333c6c | |||
| eea9261c7b | |||
| e8a08f6dd0 | |||
| 4855a2e105 | |||
| 3a7a832198 | |||
| 48ca6bd28d | |||
| f595cc6ffd | |||
| c734f1b37e | |||
| b79ce8eeaa | |||
| 76a37e198f | |||
| 7f3c7464b4 | |||
| c77ffa9c56 | |||
| 495186503c | |||
| 2c1da813b5 | |||
| 8337a11ce9 | |||
| d136136d22 | |||
| 074eb183c7 | |||
| 43ed3914b8 | |||
| 6aaf1c0870 | |||
| fe35b3ed43 | |||
| 90a9339686 | |||
| a50e77ff38 | |||
| f56c4159b5 | |||
| 5637c861b4 | |||
| 94157a8404 | |||
| 68a3521978 |
+2
-5
@@ -4,14 +4,11 @@
|
||||
|
||||
.claude
|
||||
.codex
|
||||
AGENTS.md
|
||||
|
||||
CMakeUserPresets.json
|
||||
|
||||
build
|
||||
build_release
|
||||
cmake-build-debug
|
||||
cmake-build-release
|
||||
build_*
|
||||
compile.sh
|
||||
pimcomp_utils/*
|
||||
|
||||
**/__*
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
- Always read the full README.md before doing anything.
|
||||
- Build commands:
|
||||
- `cmake --build ./build_release`
|
||||
- `cmake --build ./build_debug`
|
||||
- Never use `ninja` directly: it bypasses cmake's configuration and invalidates the build cache.
|
||||
- Always tries the release version build first and ask before building with the debug version
|
||||
|
||||
# Code changes
|
||||
|
||||
- Keep changes minimal and localized to the relevant parts of the code.
|
||||
- Preserve the existing naming conventions and coding style used in the surrounding code.
|
||||
- Keep code easy to read, well organized, and suitable for future extensibility. A function must not be longer than
|
||||
200/250 lines for readability and cognitive complexity.
|
||||
- Prefer clear naming and structure over comments. Add comments only when they materially improve clarity.
|
||||
- Do not rename symbols, move files, or restructure modules unless that is necessary for the requested change.
|
||||
|
||||
# Working style
|
||||
|
||||
- Infer style and conventions from the existing code before introducing new patterns.
|
||||
- When several implementation options are possible, prefer the simplest one that fits the current architecture and
|
||||
minimizes churn.
|
||||
- Avoid broad refactors unless I explicitly ask for them.
|
||||
|
||||
# Responses
|
||||
|
||||
- When showing code in chat, make it easy to copy-paste into the codebase.
|
||||
- Keep outputs focused on the changed parts.
|
||||
- At the end of the response, briefly list any bad practices, mistakes, or cleaner alternatives you noticed, separate
|
||||
from the main solution.
|
||||
|
||||
# Guidelines
|
||||
|
||||
## 1. Think Before Coding
|
||||
|
||||
**Don't assume. Don't hide confusion. Surface tradeoffs.**
|
||||
|
||||
Before implementing:
|
||||
|
||||
- State your assumptions explicitly. If uncertain, ask.
|
||||
- If multiple interpretations exist, present them - don't pick silently.
|
||||
- If a simpler approach exists, say so. Push back when warranted.
|
||||
- If something is unclear, stop. Name what's confusing. Ask.
|
||||
|
||||
## 2. Simplicity First
|
||||
|
||||
**Minimum code that solves the problem. Nothing speculative.**
|
||||
|
||||
- No features beyond what was asked.
|
||||
- No error handling for impossible scenarios.
|
||||
- If you write 200 lines and it could be 50, rewrite it.
|
||||
|
||||
Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
|
||||
|
||||
## 3. Surgical Changes
|
||||
|
||||
**Touch only what you must. Clean up only your own mess.**
|
||||
|
||||
When editing existing code:
|
||||
|
||||
- Don't "improve" adjacent code, comments, or formatting.
|
||||
- Don't refactor things that aren't broken.
|
||||
- Match existing style, even if you'd do it differently.
|
||||
- If you notice unrelated dead code, mention it - don't delete it.
|
||||
|
||||
When your changes create orphans:
|
||||
|
||||
- Remove imports/variables/functions that YOUR changes made unused.
|
||||
- Don't remove pre-existing dead code unless asked, but mention it.
|
||||
|
||||
The test: Every changed line should trace directly to the user's request.
|
||||
|
||||
## 4. Goal-Driven Execution
|
||||
|
||||
**Define success criteria. Loop until verified.**
|
||||
|
||||
Transform tasks into verifiable goals:
|
||||
|
||||
- "Add validation" → "Write tests for invalid inputs, then make them pass"
|
||||
- "Fix the bug" → "Write a test that reproduces it, then make it pass"
|
||||
- "Refactor X" → "Ensure tests pass before and after"
|
||||
|
||||
For multi-step tasks, state a brief plan:
|
||||
|
||||
```
|
||||
1. [Step] → verify: [check]
|
||||
2. [Step] → verify: [check]
|
||||
3. [Step] → verify: [check]
|
||||
```
|
||||
|
||||
Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
|
||||
|
||||
---
|
||||
+85
-17
@@ -3,31 +3,99 @@ cmake_minimum_required(VERSION 3.20.0)
|
||||
|
||||
project(raptor)
|
||||
|
||||
# Add symlink to PIM as accelerator in onnx-mlir
|
||||
function(raptor_ensure_symlink link_path target_path)
|
||||
get_filename_component(link_parent "${link_path}" DIRECTORY)
|
||||
# Materialize a CMake shim directory
|
||||
function(raptor_write_external_cmake_shim shim_dir external_source_dir description)
|
||||
get_filename_component(real_external_source_dir "${external_source_dir}" REALPATH)
|
||||
file(RELATIVE_PATH relative_external_source_dir "${shim_dir}" "${real_external_source_dir}")
|
||||
|
||||
if(NOT EXISTS "${link_parent}")
|
||||
message(FATAL_ERROR "Directory not found: ${link_parent}")
|
||||
endif()
|
||||
if (NOT EXISTS "${real_external_source_dir}/CMakeLists.txt")
|
||||
message(FATAL_ERROR
|
||||
"External CMake source directory not found or missing CMakeLists.txt:\n"
|
||||
" ${real_external_source_dir}"
|
||||
)
|
||||
endif ()
|
||||
|
||||
if (IS_SYMLINK "${shim_dir}")
|
||||
message(STATUS "Removing old full-directory symlink: ${shim_dir}")
|
||||
file(REMOVE "${shim_dir}")
|
||||
endif ()
|
||||
|
||||
if (EXISTS "${shim_dir}" AND NOT IS_DIRECTORY "${shim_dir}")
|
||||
message(FATAL_ERROR "Expected directory or absent path, got file: ${shim_dir}")
|
||||
endif ()
|
||||
|
||||
file(MAKE_DIRECTORY "${shim_dir}")
|
||||
|
||||
set(shim_file "${shim_dir}/CMakeLists.txt")
|
||||
set(shim_contents
|
||||
"get_filename_component(raptor_external_source_dir
|
||||
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||
REALPATH
|
||||
)
|
||||
add_subdirectory(
|
||||
\"\${raptor_external_source_dir}\"
|
||||
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||
)
|
||||
if (DEFINED PIM_ENABLED)
|
||||
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||
endif ()
|
||||
"
|
||||
)
|
||||
|
||||
if (EXISTS "${shim_file}")
|
||||
file(READ "${shim_file}" old_contents)
|
||||
else ()
|
||||
set(old_contents "")
|
||||
endif ()
|
||||
|
||||
if (NOT old_contents STREQUAL shim_contents)
|
||||
file(WRITE "${shim_file}" "${shim_contents}")
|
||||
message(STATUS "Wrote CMake shim for ${description}: ${shim_file}")
|
||||
else ()
|
||||
message(STATUS "CMake shim already up to date for ${description}")
|
||||
endif ()
|
||||
|
||||
# Mirror the external tree's first-level entries into the shim directory
|
||||
# so legacy includes like src/Accelerators/PIM/Compiler/... keep working.
|
||||
file(GLOB children RELATIVE "${real_external_source_dir}" "${real_external_source_dir}/*")
|
||||
|
||||
foreach (child IN LISTS children)
|
||||
if (child STREQUAL "CMakeLists.txt")
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
set(real_child "${real_external_source_dir}/${child}")
|
||||
set(shim_child "${shim_dir}/${child}")
|
||||
|
||||
if (IS_SYMLINK "${shim_child}")
|
||||
file(READ_SYMLINK "${shim_child}" existing_link_target)
|
||||
if (existing_link_target STREQUAL real_child)
|
||||
continue()
|
||||
endif ()
|
||||
file(REMOVE_RECURSE "${shim_child}")
|
||||
elseif (EXISTS "${shim_child}")
|
||||
# Do not delete real files/directories. This protects the generated shim.
|
||||
continue()
|
||||
endif ()
|
||||
|
||||
if(NOT EXISTS "${link_path}")
|
||||
message(STATUS "Creating symlink ${link_path} -> ${target_path}")
|
||||
file(CREATE_LINK
|
||||
"${target_path}"
|
||||
"${link_path}"
|
||||
"${real_child}"
|
||||
"${shim_child}"
|
||||
SYMBOLIC
|
||||
)
|
||||
endif()
|
||||
endforeach ()
|
||||
endfunction()
|
||||
|
||||
raptor_ensure_symlink(
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/src/Accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/src/PIM"
|
||||
"PIM accelerator"
|
||||
)
|
||||
raptor_ensure_symlink(
|
||||
|
||||
raptor_write_external_cmake_shim(
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/onnx-mlir/test/accelerators/PIM"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}/test/PIM"
|
||||
"PIM accelerator tests"
|
||||
)
|
||||
|
||||
# Patch onnx-mlir sources for PIM accelerator support.
|
||||
@@ -38,21 +106,21 @@ function(raptor_apply_patch file_path anchor replacement description)
|
||||
|
||||
# Already applied – replacement text is present
|
||||
string(FIND "${contents}" "${replacement}" already_applied_pos)
|
||||
if(NOT already_applied_pos EQUAL -1)
|
||||
if (NOT already_applied_pos EQUAL -1)
|
||||
message(STATUS "Patch already applied: ${description}")
|
||||
return()
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
# Anchor must exist for the patch to be applicable
|
||||
string(FIND "${contents}" "${anchor}" anchor_pos)
|
||||
if(anchor_pos EQUAL -1)
|
||||
if (anchor_pos EQUAL -1)
|
||||
message(FATAL_ERROR
|
||||
"Patch anchor not found – onnx-mlir may have changed.\n"
|
||||
" Patch : ${description}\n"
|
||||
" File : ${file_path}\n"
|
||||
" Anchor: ${anchor}"
|
||||
)
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
string(REPLACE "${anchor}" "${replacement}" patched "${contents}")
|
||||
file(WRITE "${file_path}" "${patched}")
|
||||
|
||||
@@ -1,226 +1,321 @@
|
||||
# Raptor
|
||||
|
||||
Raptor is a domain-specific MLIR compiler for neural networks (ONNX format)
|
||||
targeting in-memory computing / processing-in-memory (PIM) architectures.
|
||||
It progressively lowers ONNX-MLIR through a set of MLIR dialects down to
|
||||
target-specific artifacts (currently JSON code for the `pimsim-nn` simulator).
|
||||
Raptor is a domain-specific MLIR compiler for neural networks in ONNX format,
|
||||
targeting in-memory computing / processing-in-memory (PIM) architectures. It
|
||||
extends ONNX-MLIR with a PIM accelerator and progressively lowers ONNX-MLIR
|
||||
through custom MLIR dialects to simulator artifacts.
|
||||
|
||||
The current target is the PIM simulator stack under `backend-simulators/pim`.
|
||||
Raptor emits binary per-core `.pim` instruction files by default, plus
|
||||
`memory.bin`, `config.json`, and weight binaries. It can also emit per-core JSON
|
||||
instruction files with `--pim-emit-json`.
|
||||
|
||||
## Overview
|
||||
|
||||
PIM architectures perform most of the computation directly in memory.
|
||||
Raptor's first supported target is `pimsim-nn`, which simulates a chip with:
|
||||
- a shared host memory,
|
||||
- a number of cores that do most of the computation directly in their memory
|
||||
(vector ops, vmm/mvm on ReRAM crossbars),
|
||||
- no branching instructions (branchless architecture) and no hardware loop
|
||||
support — any repeated work (e.g. convolutions) must be unrolled into
|
||||
explicit per-iteration instructions.
|
||||
PIM architectures perform most computation directly in memory. The supported
|
||||
target models a chip with:
|
||||
- shared host memory,
|
||||
- multiple PIM cores,
|
||||
- ReRAM crossbars for vector-matrix / matrix-vector work,
|
||||
- explicit communication between cores,
|
||||
- no hardware branch or loop support in emitted simulator code.
|
||||
|
||||
Because of this, the amount of emitted instructions explodes quickly and the
|
||||
compiler must optimize aggressively at every stage to keep compilation
|
||||
tractable.
|
||||
|
||||
A second target, `PulPim`, is planned for an accelerator with RISC-V cores
|
||||
each carrying its own in-memory computing unit and crossbars. It will live in
|
||||
a dedicated dialect (future work).
|
||||
Because repeated work such as convolutions is eventually made explicit, emitted
|
||||
instruction counts can grow quickly. Most compiler work therefore focuses on
|
||||
lowering, scheduling, memory layout, and code-generation optimizations.
|
||||
|
||||
### Targets and simulators
|
||||
|
||||
`pimsim-nn` (under `backend-simulators/pim/pimsim-nn`) is used for
|
||||
**performance** estimates (latency, energy), but does not functionally execute
|
||||
the JSON code it consumes. To validate the numerical correctness of the JSON
|
||||
code produced by Raptor (or, for comparison, by the `pimcomp` compiler), we use
|
||||
a Rust simulator we maintain in-tree at
|
||||
`backend-simulators/pim/pim-simulator`.
|
||||
- `backend-simulators/pim/pim-simulator` is the in-tree Rust functional
|
||||
simulator used by validation. It reads Raptor's `pim/` artifact directory and
|
||||
compares simulator output against native ONNX-MLIR execution.
|
||||
- `backend-simulators/pim/pimsim-nn` is the performance simulator submodule.
|
||||
The helper scripts in `pimcomp_utils/` are for comparison with PIMCOMP-NN and
|
||||
contain local paths; treat them as local utilities, not portable workflows.
|
||||
|
||||
## Compilation pipeline
|
||||
|
||||
The PIM-related sources live under `src/PIM` and the tests under `test/PIM`.
|
||||
When working on this codebase, most changes should stay confined to those
|
||||
trees (you only need to look outside, e.g. at `onnx-mlir` or `llvm`, for
|
||||
framework-level details).
|
||||
The PIM sources live under `src/PIM` and tests under `test/PIM`. CMake exposes
|
||||
them to ONNX-MLIR through generated shim directories under
|
||||
`onnx-mlir/src/Accelerators/PIM` and `onnx-mlir/test/accelerators/PIM`.
|
||||
|
||||
High-level lowering flow:
|
||||
|
||||
```
|
||||
ONNX-MLIR ──► Spatial ──► Pim (tensor) ──► Pim (bufferized) ──► PIM JSON
|
||||
ONNX-MLIR -> Spatial -> Pim (tensor) -> Pim (bufferized) -> PIM artifacts
|
||||
```
|
||||
|
||||
1. **ONNX → Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
|
||||
Lowers ONNX ops into the `spat` dialect (`src/PIM/Dialect/Spatial`).
|
||||
Spatial models a high-level spatial in-memory accelerator: vmm/mvm
|
||||
operations are accelerated by storing a constant RHS matrix into a
|
||||
crossbar. Crossbars cannot be re-programmed during execution, have a
|
||||
limited fixed size, and there is a limited number of them per core.
|
||||
Conversion patterns are split by op family under
|
||||
`Conversion/ONNXToSpatial/Patterns/{Math,NN,Tensor}` (Conv, Gemm, MatMul,
|
||||
Elementwise, ReduceMean, Pool, Relu, Sigmoid, Softmax, Concat, Gather,
|
||||
Reshape, Resize, Split).
|
||||
1. **ONNX -> Spatial** (`src/PIM/Conversion/ONNXToSpatial`).
|
||||
Lowers supported ONNX ops into the `spat` dialect
|
||||
(`src/PIM/Dialect/Spatial`). Conversion patterns are split by op family under
|
||||
`Patterns/{Math,NN,Tensor}` and currently cover Conv, Gemm, MatMul,
|
||||
elementwise Add/Mul/Div, ReduceMean, pooling, Relu, Sigmoid, Softmax,
|
||||
Concat, Gather, Reshape, Resize, and Split.
|
||||
|
||||
2. **Spatial → Pim** (`src/PIM/Conversion/SpatialToPim`).
|
||||
Lowers Spatial to the `pim` dialect (`src/PIM/Dialect/Pim`), which
|
||||
materializes PIM cores (`pim.core`), inter-core communication
|
||||
(`pim.send` / `pim.receive`), halts, and crossbar-level operations.
|
||||
2. **Merge compute nodes**
|
||||
(`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
|
||||
Builds a compute graph, schedules it with the PEFT scheduler, and materializes
|
||||
the merge schedule into Spatial IR. Supporting scheduling code lives under
|
||||
`MergeComputeNodes/Scheduling`.
|
||||
|
||||
3. **Merge compute nodes** (`src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes`).
|
||||
A DCP-inspired heuristic (Dynamic Critical Path — see the original
|
||||
scheduling paper by Kwok & Ahmad,
|
||||
[DCP-eScience2007](https://clouds.cis.unimelb.edu.au/papers/DCP-eScience2007.pdf))
|
||||
that coarsens the virtual node graph and decides how to group compute
|
||||
nodes onto cores. Our implementation is only DCP-*inspired*: it is a
|
||||
heuristic with different assumptions from the paper (different cost
|
||||
model, constraints from crossbar capacity / core resources, and a
|
||||
windowed coarsening loop instead of full-graph reprioritization). The
|
||||
`dcp-critical-window-size` option controls how many lowest-slack virtual
|
||||
nodes each coarsening iteration considers (0 = legacy full-graph
|
||||
analysis). Related sources: `DCPGraph/DCPAnalysis.cpp`, `Graph.cpp/.hpp`,
|
||||
`MergeComputeNodesPass.cpp`.
|
||||
3. **Spatial -> Pim** (`src/PIM/Conversion/SpatialToPim`).
|
||||
Lowers Spatial operations to the `pim` dialect (`src/PIM/Dialect/Pim`),
|
||||
including `pim.core`, `pim.core_batch`, communication, tensor packing, global
|
||||
tensor materialization, and return-path normalization.
|
||||
|
||||
4. **Bufferization** (`src/PIM/Dialect/Pim/Transforms/Bufferization`).
|
||||
Converts tensor-semantics PIM IR into memref-semantics PIM IR using the
|
||||
standard MLIR `BufferizableOpInterface` machinery
|
||||
(`OpBufferizationInterfaces.*`, `PimBufferization.td`).
|
||||
Converts tensor-semantics PIM IR into memref-semantics PIM IR using MLIR's
|
||||
bufferization interfaces.
|
||||
|
||||
5. **Static memory coalescing** (`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
|
||||
Conservatively reuses same-typed local memref allocations inside PIM cores
|
||||
after bufferization and before code generation.
|
||||
5. **Static memory coalescing**
|
||||
(`src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing`).
|
||||
Reuses compatible local memref allocations inside PIM cores before codegen.
|
||||
|
||||
6. **PIM code generation** (`src/PIM/Pass/PimCodegen`):
|
||||
- `HostConstantFolding` — folds host-side constants.
|
||||
- `MaterializeHostConstantsPass` — materializes the remaining host
|
||||
constants for emission.
|
||||
- `VerificationPass` — checks invariants before emission.
|
||||
- `EmitPimJsonPass` — emits the final PIM JSON consumed by `pimsim-nn`
|
||||
and `pim-simulator`.
|
||||
6. **PIM code generation** (`src/PIM/Pass/PimCodegen` and
|
||||
`src/PIM/Compiler`).
|
||||
Folds host constants, materializes remaining host constants, verifies PIM IR,
|
||||
emits `.pim` core files, writes weights, and writes `memory.bin` /
|
||||
`config.json`.
|
||||
|
||||
Supporting pieces:
|
||||
- `src/PIM/Compiler` — PIM-specific compiler options (crossbar size/count,
|
||||
core count, DCP window, experimental conv impl, concat error handling, …)
|
||||
and `PimCodeGen` entry points.
|
||||
- `src/PIM/Common` — shared utilities (`PimCommon`, `LabeledList`).
|
||||
- `src/PIM/Pass` — auxiliary passes (`MessagePass`, `CountInstructionPass`)
|
||||
and the `PIMPasses.h` registry used by `PimAccelerator`.
|
||||
- `src/PIM/PimAccelerator.{cpp,hpp}` — accelerator entry point: registers
|
||||
dialects, passes, and plugs Raptor into the ONNX-MLIR driver.
|
||||
- `src/PIM/Common` - shared IR, filesystem, diagnostics, reports, and utility
|
||||
helpers.
|
||||
- `src/PIM/Compiler` - PIM compiler options, memory/address planning, binary
|
||||
instruction format, artifact writing, weight emission, and codegen entry
|
||||
points.
|
||||
- `src/PIM/Conversion/SpatialToGraphviz` - optional Spatial graphviz conversion
|
||||
pass.
|
||||
- `src/PIM/Pass` - pass registration and auxiliary passes.
|
||||
- `src/PIM/PimAccelerator.{cpp,hpp}` - ONNX-MLIR accelerator entry point.
|
||||
|
||||
## Key compiler options
|
||||
|
||||
Pass these on the `onnx-mlir` command line when compiling for PIM:
|
||||
Pass these to `onnx-mlir` when compiling for PIM:
|
||||
|
||||
- `--maccel=PIM` — select the PIM accelerator.
|
||||
- `--EmitSpatial` / `--EmitPim` / `--EmitPimBufferized` / `--EmitPimCodegen`
|
||||
— stop the pipeline at the requested stage (default: `EmitPimCodegen`).
|
||||
- `--pim-only-codegen` — assume the input is already bufferized PIM IR and
|
||||
run only the codegen tail.
|
||||
- `--crossbar-size=<N>` / `--crossbar-count=<N>` — crossbar dimensions and
|
||||
per-core count.
|
||||
- `--core-count=<N>` — number of cores. Required for PIM compilation.
|
||||
- `--pim-merge-scheduler={peft,dcp}` — scheduler used by the Spatial
|
||||
merge-compute-nodes pass (default: `peft`).
|
||||
- `--dcp-critical-window-size=<N>` — DCP coarsening window (0 = legacy).
|
||||
- `--use-experimental-conv-impl` — alternative convolution lowering.
|
||||
- `--ignore-concat-error` — soft-fail corner case in `ConcatOp`.
|
||||
- `--maccel=PIM` - select the PIM accelerator.
|
||||
- `--EmitSpatial`, `--EmitPim`, `--EmitPimBufferized`,
|
||||
`--EmitPimCodegen` - stop the PIM pipeline at the requested stage. The PIM
|
||||
default is `--EmitPimCodegen`.
|
||||
- `--core-count=<N>` - required positive core count for PIM compilation.
|
||||
- `--crossbar-size=<N>` - crossbar width/height. Default in code is `2`.
|
||||
- `--crossbar-count=<N>` - crossbars per core. Default in code is `256`.
|
||||
- `--pim-merge-scheduler=peft` - merge scheduler. `peft` is the only accepted
|
||||
value in the current code.
|
||||
- `--pim-only-codegen` - assume input is already bufferized PIM IR and only run
|
||||
the codegen tail.
|
||||
- `--pim-emit-json` - also emit `core_*.json` instruction files alongside
|
||||
`core_*.pim`.
|
||||
- `--use-experimental-conv-impl` - use the alternate convolution lowering.
|
||||
- `--ignore-concat-error` - soft-fail a ConcatOp corner case.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
./build_release/Release/bin/onnx-mlir model.onnx -o /tmp/raptor/model \
|
||||
--maccel=PIM --EmitPimCodegen \
|
||||
--crossbar-size=2048 --crossbar-count=256 --core-count=1000
|
||||
```
|
||||
|
||||
This writes PIM artifacts under `/tmp/raptor/pim/`.
|
||||
|
||||
## Validation
|
||||
|
||||
Functional validation lives in `validation/` and drives the Rust
|
||||
`pim-simulator` to compare Raptor's output against a reference.
|
||||
Functional validation lives in `validation/`. It compiles ONNX models, builds a
|
||||
native ONNX-MLIR reference runner, generates random inputs, runs Raptor, runs
|
||||
the Rust PIM simulator, and compares outputs.
|
||||
|
||||
Per-operation validation (from `validation/`):
|
||||
Python dependencies used by the validation scripts are `numpy`, `onnx`, and
|
||||
`colorama`. The simulator requires the Rust toolchain.
|
||||
|
||||
```
|
||||
validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
Per-operation validation from the repository root:
|
||||
|
||||
```bash
|
||||
python3 validation/validate.py \
|
||||
--raptor-path build_release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir onnx-mlir/include \
|
||||
--core-count 1000
|
||||
```
|
||||
|
||||
End-to-end network validation (example: first 4 layers of YOLOv11n):
|
||||
Validate one network or a subset by pointing `--operations-dir` at any directory
|
||||
containing `.onnx` files:
|
||||
|
||||
```
|
||||
validate.py \
|
||||
--raptor-path ../cmake-build-release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir ../onnx-mlir/include \
|
||||
--operations-dir ./networks/yolo11n/depth_04 \
|
||||
```bash
|
||||
python3 validation/validate.py \
|
||||
--raptor-path build_release/Release/bin/onnx-mlir \
|
||||
--onnx-include-dir onnx-mlir/include \
|
||||
--operations-dir validation/networks/yolo11n/depth_04 \
|
||||
--crossbar-size 2048 --crossbar-count 256 --core-count 1000
|
||||
```
|
||||
|
||||
Available networks under `validation/networks/`: `vgg16`, `yolo11n`.
|
||||
Available operations under `validation/operations/`: `add`, `conv`, `div`,
|
||||
`gather`, `gemm`, `gemv`, `mul`, `pool`, `reduce_mean`, `relu`, `resize`,
|
||||
`sigmoid`, `softmax`, `split`.
|
||||
Useful validation options:
|
||||
- `--simulator-dir <path>` - override the auto-detected
|
||||
`backend-simulators/pim/pim-simulator` path.
|
||||
- `--threshold <float>` - maximum allowed per-element output difference.
|
||||
- `--seed <int>` - RNG seed for generated inputs.
|
||||
- `--command-timeout-seconds <float>` - timeout for compiler, runner, and
|
||||
simulator subprocesses.
|
||||
- `--verbose` - print subprocess logs and average PIM pass timings.
|
||||
- `--clean` - remove generated validation artifacts and exit.
|
||||
|
||||
## Rebuilding
|
||||
Each validation run writes artifacts in the model workspace, for example under
|
||||
`validation/operations/gemm/small/`:
|
||||
- `inputs/` - generated input CSV files.
|
||||
- `outputs/` - native ONNX-MLIR reference outputs.
|
||||
- `raptor/` - compiler artifacts, including `*.onnx.mlir`, dialect dumps under
|
||||
`dialects/`, reports under `reports/`, and final PIM artifacts under `pim/`.
|
||||
- `runner/` - generated reference runner source, build tree, and shared library.
|
||||
- `simulation/out.bin` - raw simulator output used for comparison.
|
||||
|
||||
Release build (fast):
|
||||
The compiler currently dumps dialect snapshots such as `spatial0.mlir`,
|
||||
`spatial1_dcp_merged.mlir`, `pim0.mlir`, `pim1_buff.mlir`,
|
||||
`pim2_coalesced.mlir`, `pim3_folded.mlir`, and
|
||||
`pim4_materialized.mlir` when an output directory is available.
|
||||
|
||||
```
|
||||
cmake --build /home/nico/raptor/raptor/cmake-build-release --target onnx-mlir -j 30
|
||||
To rerun the simulator manually with tracing after validation has produced a
|
||||
`raptor/pim/` directory:
|
||||
|
||||
```bash
|
||||
cd backend-simulators/pim/pim-simulator
|
||||
cargo run --no-default-features --features tracing --release \
|
||||
--package pim-simulator --bin pim-simulator -- \
|
||||
-f /path/to/workspace/raptor/pim \
|
||||
-o /path/to/workspace/simulation/out.bin \
|
||||
-d <addr0>,<size0>,<addr1>,<size1>,...
|
||||
```
|
||||
|
||||
A slower debug build is also available — configure it the same way but with
|
||||
`-DCMAKE_BUILD_TYPE=Debug` (see installation instructions below).
|
||||
With `--features tracing`, the simulator writes per-core traces as
|
||||
`TraceCore0`, `TraceCore1`, ... next to `out.bin`. The validator normally
|
||||
computes the `-d` ranges from `raptor/pim/config.json` and model output shapes.
|
||||
|
||||
Available validation networks under `validation/networks/`: `vgg16`,
|
||||
`yolo11n`, `yolo11nv2`.
|
||||
|
||||
Available operation suites under `validation/operations/`: `add`, `concat`,
|
||||
`conv`, `div`, `gather`, `gemm`, `gemv`, `matmul`, `mul`, `pool`,
|
||||
`reduce_mean`, `relu`, `reshape`, `resize`, `sigmoid`, `softmax`, `split`.
|
||||
|
||||
Generated operation tests can be regenerated with:
|
||||
|
||||
```bash
|
||||
python3 validation/operations/gen_tests.py
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
Initialize submodules first:
|
||||
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
The project follows ONNX-MLIR's build requirements. The CI workflow documents
|
||||
the currently used versions and setup:
|
||||
- CMake 4.3.0 in CI,
|
||||
- LLVM/MLIR checked out under `onnx-mlir/llvm-project`,
|
||||
- Protobuf `v34.0`,
|
||||
- Rust stable for `pim-simulator`,
|
||||
- Python packages `numpy`, `onnx`, `colorama` for validation.
|
||||
|
||||
### Protobuf
|
||||
|
||||
Use the following commands to install protobuf:
|
||||
```
|
||||
Install Protobuf if your system does not already provide a compatible version:
|
||||
|
||||
```bash
|
||||
git clone --depth 1 --branch v34.0 https://github.com/protocolbuffers/protobuf
|
||||
cd protobuf
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release
|
||||
ninja
|
||||
sudo ninja install
|
||||
cmake -S protobuf -B protobuf/build -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-Dprotobuf_BUILD_TESTS=OFF
|
||||
cmake --build protobuf/build
|
||||
sudo cmake --install protobuf/build
|
||||
```
|
||||
|
||||
You can now remove the protobuf repo directory with:
|
||||
```
|
||||
cd ../..
|
||||
You can then remove the temporary checkout:
|
||||
|
||||
```bash
|
||||
rm -rf protobuf
|
||||
```
|
||||
|
||||
### Mlir
|
||||
### MLIR
|
||||
|
||||
Follow the first part of instructions [here](onnx-mlir/docs/BuildOnLinuxOSX.md) to build mlir.
|
||||
Follow the ONNX-MLIR instructions in
|
||||
`onnx-mlir/docs/BuildOnLinuxOSX.md` to build LLVM/MLIR. The local Raptor build
|
||||
expects `MLIR_DIR` to point at the MLIR CMake package, for example:
|
||||
|
||||
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor
|
||||
|
||||
Moreover, if compiling with build type debug, it is also suggested to use
|
||||
mold as linker (you will need to install it if you don't have it already)
|
||||
to reduce memory usage during linking. You can use it by setting the options:
|
||||
```
|
||||
-DLLVM_USE_LINKER=mold
|
||||
```bash
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
|
||||
```
|
||||
|
||||
If your LLVM build directory is named `build` instead of `build_release`, adjust
|
||||
the path accordingly.
|
||||
|
||||
### Raptor
|
||||
|
||||
Use the following commands to build Raptor.
|
||||
Configure a release build:
|
||||
|
||||
Remember to set ```-DCMAKE_BUILD_TYPE=Debug``` for developing on Raptor.
|
||||
|
||||
Also in this case, it is suggested to use mold as linker to reduce link time and memory usage,
|
||||
setting the options:
|
||||
```
|
||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
|
||||
```
|
||||
|
||||
```
|
||||
git submodule update --init --recursive
|
||||
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build/lib/cmake/mlir
|
||||
mkdir build && cd build
|
||||
cmake .. -G Ninja \
|
||||
```bash
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_release/lib/cmake/mlir
|
||||
cmake -S . -B build_release -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR}
|
||||
cmake --build .
|
||||
```
|
||||
|
||||
If the build fails because of protobuf missing uint definitions,
|
||||
just patch the problematic files by adding ```#include <cstdint>``` to their includes.
|
||||
Configure a debug build similarly:
|
||||
|
||||
```bash
|
||||
MLIR_DIR=$(pwd)/onnx-mlir/llvm-project/build_debug/lib/cmake/mlir
|
||||
cmake -S . -B build_debug -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR}
|
||||
```
|
||||
|
||||
For debug development, using `mold` can reduce link time and memory use:
|
||||
|
||||
```bash
|
||||
cmake -S . -B build_debug -G Ninja \
|
||||
-DCMAKE_BUILD_TYPE=Debug \
|
||||
-DONNX_MLIR_ACCELERATORS=PIM \
|
||||
-DLLVM_ENABLE_ASSERTIONS=ON \
|
||||
-DMLIR_DIR=${MLIR_DIR} \
|
||||
-DCMAKE_EXE_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_SHARED_LINKER_FLAGS="-fuse-ld=mold" \
|
||||
-DCMAKE_MODULE_LINKER_FLAGS="-fuse-ld=mold"
|
||||
```
|
||||
|
||||
Build the compiler with CMake:
|
||||
|
||||
```bash
|
||||
cmake --build ./build_release
|
||||
cmake --build ./build_debug
|
||||
```
|
||||
|
||||
Do not invoke `ninja` directly for this project; use `cmake --build` so CMake's
|
||||
configuration and generated shims stay consistent.
|
||||
|
||||
If a build fails because Protobuf headers are missing fixed-width integer
|
||||
definitions, patch the affected Protobuf-generated files by adding
|
||||
`#include <cstdint>`.
|
||||
|
||||
## Tests
|
||||
|
||||
The Rust simulator has its own tests:
|
||||
|
||||
```bash
|
||||
cd backend-simulators/pim/pim-simulator
|
||||
cargo test
|
||||
```
|
||||
|
||||
## Repository Layout
|
||||
|
||||
- `src/PIM/` - PIM accelerator implementation.
|
||||
- `test/PIM/` - PIM C++ unit tests.
|
||||
- `validation/` - functional validation scripts, ONNX operation tests, network
|
||||
slices, and pimsim config generation.
|
||||
- `backend-simulators/pim/pim-simulator/` - in-tree Rust functional simulator.
|
||||
- `backend-simulators/pim/pimsim-nn/` - performance simulator submodule.
|
||||
- `pimcomp_utils/` - local comparison helpers for PIMCOMP-NN.
|
||||
- `.github/actions/` and `.github/workflows/validate_operations.yml` - CI setup
|
||||
for MLIR/Protobuf caching, building Raptor, and validation.
|
||||
|
||||
@@ -43,7 +43,7 @@ struct Args {
|
||||
|
||||
/// Comma separated list of (address,size) for memory output dump
|
||||
#[arg(short, long, value_delimiter = ',', num_args = 1.., value_name = "ADDR,SIZE")]
|
||||
dump: Vec<i32>,
|
||||
dump: Vec<usize>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@@ -168,7 +168,7 @@ fn get_crossbars(config: &Value, args: &Args) -> anyhow::Result<HashMap<String,
|
||||
}
|
||||
|
||||
fn dump_memory(mut executor: pimcore::Executable, args: &Args) -> Result<()> {
|
||||
let dumps: Vec<(i32, i32)> = args
|
||||
let dumps: Vec<(usize, usize)> = args
|
||||
.dump
|
||||
.chunks_exact(2)
|
||||
.map(|chunk| (chunk[0], chunk[1]))
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::utility::AddressArg;
|
||||
use std::{collections::HashMap, fmt::Debug};
|
||||
use anyhow::{Context, Result, ensure};
|
||||
|
||||
@@ -9,6 +10,7 @@ use crate::{
|
||||
|
||||
pub mod crossbar;
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CPU<'a> {
|
||||
cores: Box<[Core<'a>]>,
|
||||
@@ -91,30 +93,26 @@ impl<'a> Core<'a> {
|
||||
self.memory.execute_load()
|
||||
}
|
||||
|
||||
pub fn execute_store<T>(&mut self, address: impl TryToUsize, element: &[T]) -> Result<()>
|
||||
pub fn execute_store<T>(&mut self, address: impl AddressArg, element: &[T]) -> Result<()>
|
||||
where
|
||||
T: MemoryStorable,
|
||||
{
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
let address = address.to_address_usize()?;
|
||||
self.memory.execute_store(address, element)
|
||||
}
|
||||
|
||||
pub fn reserve_load(
|
||||
&mut self,
|
||||
address: impl TryToUsize,
|
||||
address: impl AddressArg,
|
||||
size: impl TryToUsize,
|
||||
) -> Result<&mut CoreMemory> {
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
let address = address.to_address_usize()?;
|
||||
let size = size.try_into().context("size can not be negative")?;
|
||||
self.memory.reserve_load(address, size)
|
||||
}
|
||||
|
||||
pub fn set_register(&mut self, index: impl TryToUsize, value: i32) {
|
||||
let index = index.try_into().expect("index can not be negative");
|
||||
assert!(
|
||||
value >= 0,
|
||||
"Register cannot be negative if happens remove this and go check where it's used as usize"
|
||||
);
|
||||
self.registers[index] = value;
|
||||
}
|
||||
|
||||
@@ -123,11 +121,11 @@ impl<'a> Core<'a> {
|
||||
self.registers[index]
|
||||
}
|
||||
|
||||
pub fn load<T>(&mut self, address: impl TryToUsize, size: impl TryToUsize) -> Result<Vec<&[T]>>
|
||||
pub fn load<T>(&mut self, address: impl AddressArg, size: impl TryToUsize) -> Result<Vec<&[T]>>
|
||||
where
|
||||
T: MemoryStorable,
|
||||
{
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
let address = address.to_address_usize()?;
|
||||
let size = size.try_into().context("size can not be negative")?;
|
||||
self.memory.load(address, size)
|
||||
}
|
||||
@@ -141,8 +139,8 @@ impl<'a> Core<'a> {
|
||||
(memory, crossbars)
|
||||
}
|
||||
|
||||
pub fn memset(&mut self, address: impl TryToUsize, size: impl TryToUsize, val: u8) -> Result<()> {
|
||||
let address = address.try_into().context("address can not be negative")?;
|
||||
pub fn memset(&mut self, address: impl AddressArg, size: impl TryToUsize, val: u8) -> Result<()> {
|
||||
let address = address.to_address_usize()?;
|
||||
let size = size.try_into().context("size can not be negative")?;
|
||||
self.memory.memset(address, size, val)
|
||||
}
|
||||
|
||||
@@ -299,10 +299,11 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
|
||||
if in_path.contains(&waiting_for) {
|
||||
let cycle_start = path.iter().position(|&c| c == waiting_for).unwrap();
|
||||
let cycle = &path[cycle_start..];
|
||||
let format_core = |core: &i32| (core - 1).to_string();
|
||||
|
||||
let cycle_str = cycle
|
||||
.iter()
|
||||
.map(|c| c.to_string())
|
||||
.map(format_core)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" -> ");
|
||||
|
||||
@@ -311,19 +312,19 @@ fn detect_deadlock(cores_instructions: &[CoreInstructions]) -> Option<DeadlockIn
|
||||
.copied()
|
||||
.chain(std::iter::once(waiting_for))
|
||||
.collect::<Vec<_>>();
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for);
|
||||
let cycle_msg = format!("{} -> {}", cycle_str, waiting_for - 1);
|
||||
let states_msg = cycle
|
||||
.iter()
|
||||
.filter_map(|core| {
|
||||
states.get(core).map(|state| match state {
|
||||
CoreState::SendingTo(target, size) => {
|
||||
format!("core {} send {}B -> {}", core, size, target)
|
||||
format!("core {} send {}B -> {}", core - 1, size, target - 1)
|
||||
}
|
||||
CoreState::ReceivingFrom(source, size) => {
|
||||
format!("core {} recv {}B <- {}", core, size, source)
|
||||
format!("core {} recv {}B <- {}", core - 1, size, source - 1)
|
||||
}
|
||||
CoreState::Working => format!("core {} working", core),
|
||||
CoreState::Halted => format!("core {} halted", core),
|
||||
CoreState::Working => format!("core {} working", core - 1),
|
||||
CoreState::Halted => format!("core {} halted", core - 1),
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
|
||||
@@ -1,7 +1,45 @@
|
||||
use anyhow::{Result,Context};
|
||||
use std::{fmt::Debug, mem::transmute};
|
||||
|
||||
use crate::memory_manager::type_traits::TryToUsize;
|
||||
|
||||
pub trait AddressArg {
|
||||
fn to_address_usize(self) -> Result<usize>;
|
||||
}
|
||||
|
||||
impl AddressArg for usize {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for u32 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
Ok(self as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for u64 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
usize::try_from(self).context("address does not fit in usize")
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for i32 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
Ok(self as u32 as usize)
|
||||
}
|
||||
}
|
||||
|
||||
impl AddressArg for i64 {
|
||||
fn to_address_usize(self) -> Result<usize> {
|
||||
usize::try_from(self).context("address can not be negative")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
fn address_to_usize(address: i32) -> usize {
|
||||
address as u32 as usize
|
||||
}
|
||||
|
||||
fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i32) -> usize{
|
||||
assert!(offset_select == 1 || offset_select == 2 || offset_select == 4 || offset_value == 0, "offset_select not a bit field");
|
||||
@@ -14,21 +52,21 @@ fn add_offset_impl(address: usize, offset_select : i32, offset_value : i32, id:i
|
||||
}
|
||||
|
||||
|
||||
pub fn add_offset_rd(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
|
||||
pub fn add_offset_rd(address: i32, offset_select : i32, offset_value : i32) -> usize
|
||||
{
|
||||
let address = address.try_into().expect("address can not be negative");
|
||||
let address = address_to_usize(address);
|
||||
add_offset_impl(address, offset_select, offset_value, 4)
|
||||
}
|
||||
|
||||
pub fn add_offset_r1(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
|
||||
pub fn add_offset_r1(address: i32, offset_select : i32, offset_value : i32) -> usize
|
||||
{
|
||||
let address = address.try_into().expect("address can not be negative");
|
||||
let address = address_to_usize(address);
|
||||
add_offset_impl(address, offset_select, offset_value, 1)
|
||||
}
|
||||
|
||||
pub fn add_offset_r2(address: impl TryToUsize, offset_select : i32, offset_value : i32) -> usize
|
||||
pub fn add_offset_r2(address: i32, offset_select : i32, offset_value : i32) -> usize
|
||||
{
|
||||
let address = address.try_into().expect("address can not be negative");
|
||||
let address = address_to_usize(address);
|
||||
add_offset_impl(address, offset_select, offset_value, 2)
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,56 @@ set(PIM_INCLUDE_PATH ${CMAKE_INCLUDE_OUTPUT_DIRECTORY})
|
||||
set(PIM_ONNX_MLIR_SRC_ROOT ${ONNX_MLIR_SRC_ROOT})
|
||||
set(PIM_ONNX_MLIR_BIN_ROOT ${ONNX_MLIR_BIN_ROOT})
|
||||
|
||||
set(PIM_GENERATED_PATH_SHIM_TARGET "")
|
||||
get_filename_component(PIM_BIN_ROOT_NAME "${PIM_BIN_ROOT}" NAME)
|
||||
if (PIM_BIN_ROOT_NAME STREQUAL "raptor-external")
|
||||
get_filename_component(PIM_GENERATED_PATH_SHIM_ROOT "${PIM_BIN_ROOT}" DIRECTORY)
|
||||
set(PIM_GENERATED_PATH_SHIM_OUTPUTS)
|
||||
|
||||
function(add_pim_generated_path_shim relative_path)
|
||||
set(real_file "${PIM_BIN_ROOT}/${relative_path}")
|
||||
set(shim_file "${PIM_GENERATED_PATH_SHIM_ROOT}/${relative_path}")
|
||||
get_filename_component(shim_dir "${shim_file}" DIRECTORY)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT "${shim_file}"
|
||||
DEPENDS "${real_file}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E make_directory "${shim_dir}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E rm -f "${shim_file}"
|
||||
COMMAND "${CMAKE_COMMAND}" -E create_symlink "${real_file}" "${shim_file}"
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
list(APPEND PIM_GENERATED_PATH_SHIM_OUTPUTS "${shim_file}")
|
||||
set(PIM_GENERATED_PATH_SHIM_OUTPUTS "${PIM_GENERATED_PATH_SHIM_OUTPUTS}" PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
file(GLOB_RECURSE pim_generated_path_scan_sources
|
||||
CONFIGURE_DEPENDS
|
||||
"${PIM_SRC_ROOT}/*.cpp"
|
||||
"${PIM_SRC_ROOT}/*.hpp"
|
||||
)
|
||||
|
||||
set(pim_generated_path_shims)
|
||||
foreach (source_file IN LISTS pim_generated_path_scan_sources)
|
||||
file(READ "${source_file}" source_contents)
|
||||
string(REGEX MATCHALL "#include \"src/Accelerators/PIM/[^\"]+\\.inc\"" source_inc_matches "${source_contents}")
|
||||
|
||||
foreach (inc_match IN LISTS source_inc_matches)
|
||||
string(REGEX REPLACE "^#include \"src/Accelerators/PIM/(.+)\"$" "\\1" relative_inc_path "${inc_match}")
|
||||
list(APPEND pim_generated_path_shims "${relative_inc_path}")
|
||||
endforeach ()
|
||||
endforeach ()
|
||||
|
||||
list(REMOVE_DUPLICATES pim_generated_path_shims)
|
||||
foreach (relative_inc_path IN LISTS pim_generated_path_shims)
|
||||
add_pim_generated_path_shim("${relative_inc_path}")
|
||||
endforeach ()
|
||||
|
||||
add_custom_target(OMPimGeneratedPathShims DEPENDS ${PIM_GENERATED_PATH_SHIM_OUTPUTS})
|
||||
set(PIM_GENERATED_PATH_SHIM_TARGET OMPimGeneratedPathShims)
|
||||
endif ()
|
||||
|
||||
set(PIM_PUBLIC_INCLUDE_DIRS
|
||||
${ONNX_MLIR_SRC_ROOT}/include
|
||||
${ONNX_MLIR_SRC_ROOT}
|
||||
@@ -37,6 +87,9 @@ set(PIM_GENERATED_INCLUDE_DIRS
|
||||
|
||||
function(add_pim_library name)
|
||||
add_onnx_mlir_library(${name} STATIC ${ARGN})
|
||||
if (PIM_GENERATED_PATH_SHIM_TARGET)
|
||||
add_dependencies(${name} ${PIM_GENERATED_PATH_SHIM_TARGET})
|
||||
endif ()
|
||||
endfunction()
|
||||
|
||||
add_subdirectory(Dialect)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
add_pim_library(OMPimCommon
|
||||
IR/AffineUtils.cpp
|
||||
IR/AddressAnalysis.cpp
|
||||
IR/BatchCoreUtils.cpp
|
||||
IR/ConstantUtils.cpp
|
||||
IR/CoreBlockUtils.cpp
|
||||
IR/EntryPointUtils.cpp
|
||||
IR/ShapeUtils.cpp
|
||||
@@ -16,6 +19,7 @@ add_pim_library(OMPimCommon
|
||||
${PIM_PUBLIC_INCLUDE_DIRS}
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
onnx
|
||||
SpatialOps
|
||||
PimOps
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
@@ -28,6 +32,14 @@ mlir::Value resolveAlias(mlir::Value value, const StaticValueKnowledge* knowledg
|
||||
return value;
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value);
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value);
|
||||
|
||||
template <typename... Args>
|
||||
CompiledIndexExpr makeCompiledIndexExpr(Args&&... args) {
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
|
||||
@@ -55,6 +67,288 @@ mlir::Value resolveLoopCarriedAliasImpl(mlir::Value value, const StaticValueKnow
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveOpFoldResult(mlir::OpFoldResult ofr, const StaticValueKnowledge* knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge);
|
||||
|
||||
static llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticMemRefStrides(mlir::MemRefType type) {
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
if (failed(type.getStridesAndOffset(strides, offset)))
|
||||
return mlir::failure();
|
||||
if (llvm::any_of(strides, mlir::ShapedType::isDynamic))
|
||||
return mlir::failure();
|
||||
return strides;
|
||||
}
|
||||
|
||||
static llvm::FailureOr<int64_t> resolveConstantGlobalLoad(mlir::memref::LoadOp loadOp,
|
||||
const StaticValueKnowledge* knowledge) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto resolvedIndex = resolveIndexValueImpl(index, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
if (indices.size() != static_cast<size_t>(globalType.getRank()))
|
||||
return mlir::failure();
|
||||
|
||||
auto strides = computeRowMajorStrides(globalType.getShape());
|
||||
int64_t linearIndex = linearizeIndex(indices, strides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
|
||||
static bool evaluateCmpPredicate(mlir::arith::CmpIPredicate predicate, int64_t lhs, int64_t rhs) {
|
||||
switch (predicate) {
|
||||
case mlir::arith::CmpIPredicate::eq: return lhs == rhs;
|
||||
case mlir::arith::CmpIPredicate::ne: return lhs != rhs;
|
||||
case mlir::arith::CmpIPredicate::slt: return lhs < rhs;
|
||||
case mlir::arith::CmpIPredicate::sle: return lhs <= rhs;
|
||||
case mlir::arith::CmpIPredicate::sgt: return lhs > rhs;
|
||||
case mlir::arith::CmpIPredicate::sge: return lhs >= rhs;
|
||||
case mlir::arith::CmpIPredicate::ult: return static_cast<uint64_t>(lhs) < static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ule: return static_cast<uint64_t>(lhs) <= static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::ugt: return static_cast<uint64_t>(lhs) > static_cast<uint64_t>(rhs);
|
||||
case mlir::arith::CmpIPredicate::uge: return static_cast<uint64_t>(lhs) >= static_cast<uint64_t>(rhs);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cmpi predicate");
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateCompiledIndexExpr(const CompiledIndexExpr& expr,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
if (!expr.node)
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Constant: return expr.node->constant;
|
||||
case CompiledIndexExprNode::Kind::Symbol: {
|
||||
auto value = resolveAlias(expr.node->symbol, &knowledge);
|
||||
auto iter = knowledge.indexValues.find(value);
|
||||
if (iter != knowledge.indexValues.end())
|
||||
return iter->second;
|
||||
return mlir::failure();
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Add:
|
||||
case CompiledIndexExprNode::Kind::Sub:
|
||||
case CompiledIndexExprNode::Kind::Mul:
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
case CompiledIndexExprNode::Kind::DivSI:
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
case CompiledIndexExprNode::Kind::CmpI: {
|
||||
auto lhs = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
auto rhs = evaluateCompiledIndexExpr(expr.node->operands[1], knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
switch (expr.node->kind) {
|
||||
case CompiledIndexExprNode::Kind::Add: return *lhs + *rhs;
|
||||
case CompiledIndexExprNode::Kind::Sub: return *lhs - *rhs;
|
||||
case CompiledIndexExprNode::Kind::Mul: return *lhs * *rhs;
|
||||
case CompiledIndexExprNode::Kind::DivUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
case CompiledIndexExprNode::Kind::DivSI:
|
||||
if (*rhs == 0 || (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1))
|
||||
return mlir::failure();
|
||||
return *lhs / *rhs;
|
||||
case CompiledIndexExprNode::Kind::RemUI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
case CompiledIndexExprNode::Kind::RemSI:
|
||||
if (*rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return 0;
|
||||
return *lhs % *rhs;
|
||||
case CompiledIndexExprNode::Kind::MinUI:
|
||||
return static_cast<int64_t>(std::min(static_cast<uint64_t>(*lhs), static_cast<uint64_t>(*rhs)));
|
||||
case CompiledIndexExprNode::Kind::CmpI: return evaluateCmpPredicate(expr.node->predicate, *lhs, *rhs) ? 1 : 0;
|
||||
default: llvm_unreachable("unexpected binary compiled index kind");
|
||||
}
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::Select: {
|
||||
auto condition = evaluateCompiledIndexExpr(expr.node->operands[0], knowledge);
|
||||
if (failed(condition))
|
||||
return mlir::failure();
|
||||
return evaluateCompiledIndexExpr(*condition != 0 ? expr.node->operands[1] : expr.node->operands[2], knowledge);
|
||||
}
|
||||
case CompiledIndexExprNode::Kind::ConstantGlobalLoad: {
|
||||
if (!expr.node->globalOp || !expr.node->globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*expr.node->globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(expr.node->globalOp.getType());
|
||||
if (!denseAttr || !globalType)
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> indices;
|
||||
indices.reserve(expr.node->operands.size());
|
||||
for (const CompiledIndexExpr& operand : expr.node->operands) {
|
||||
auto resolvedIndex = evaluateCompiledIndexExpr(operand, knowledge);
|
||||
if (failed(resolvedIndex))
|
||||
return mlir::failure();
|
||||
indices.push_back(*resolvedIndex);
|
||||
}
|
||||
|
||||
int64_t linearIndex = linearizeIndex(indices, expr.node->globalStrides);
|
||||
if (linearIndex < 0 || linearIndex >= globalType.getNumElements())
|
||||
return mlir::failure();
|
||||
return denseAttr.getValues<llvm::APInt>()[linearIndex].getSExtValue();
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown compiled index kind");
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileConstantGlobalLoad(mlir::memref::LoadOp loadOp) {
|
||||
auto getGlobalOp = loadOp.getMemRef().getDefiningOp<mlir::memref::GetGlobalOp>();
|
||||
if (!getGlobalOp)
|
||||
return mlir::failure();
|
||||
|
||||
auto moduleOp = loadOp->getParentOfType<mlir::ModuleOp>();
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getConstant() || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
auto globalType = mlir::dyn_cast<mlir::MemRefType>(getGlobalOp.getType());
|
||||
if (!denseAttr || !globalType || !globalType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
auto elementType = denseAttr.getElementType();
|
||||
if (!elementType.isIndex() && !elementType.isInteger())
|
||||
return mlir::failure();
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::ConstantGlobalLoad;
|
||||
expr.globalOp = globalOp;
|
||||
expr.globalStrides = computeRowMajorStrides(globalType.getShape());
|
||||
expr.operands.reserve(loadOp.getIndices().size());
|
||||
for (mlir::Value index : loadOp.getIndices()) {
|
||||
auto compiledIndex = compileIndexValueImpl(index);
|
||||
if (failed(compiledIndex))
|
||||
return mlir::failure();
|
||||
expr.operands.push_back(*compiledIndex);
|
||||
}
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexValueImpl(mlir::Value value) {
|
||||
if (auto constantOp = value.getDefiningOp<mlir::arith::ConstantOp>()) {
|
||||
if (auto integerAttr = mlir::dyn_cast<mlir::IntegerAttr>(constantOp.getValue())) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = integerAttr.getInt();
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Symbol;
|
||||
expr.symbol = value;
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
auto buildBinaryExpr = [&](CompiledIndexExprNode::Kind kind, mlir::Value lhsValue, mlir::Value rhsValue) {
|
||||
auto lhs = compileIndexValueImpl(lhsValue);
|
||||
auto rhs = compileIndexValueImpl(rhsValue);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return llvm::FailureOr<CompiledIndexExpr>(mlir::failure());
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = kind;
|
||||
expr.operands = {*lhs, *rhs};
|
||||
return llvm::FailureOr<CompiledIndexExpr>(makeCompiledIndexExpr(std::move(expr)));
|
||||
};
|
||||
|
||||
if (auto indexCastOp = mlir::dyn_cast<mlir::arith::IndexCastOp>(definingOp))
|
||||
return compileIndexValueImpl(indexCastOp.getIn());
|
||||
if (auto addOp = mlir::dyn_cast<mlir::arith::AddIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Add, addOp.getLhs(), addOp.getRhs());
|
||||
if (auto subOp = mlir::dyn_cast<mlir::arith::SubIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Sub, subOp.getLhs(), subOp.getRhs());
|
||||
if (auto mulOp = mlir::dyn_cast<mlir::arith::MulIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::Mul, mulOp.getLhs(), mulOp.getRhs());
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivUI, divOp.getLhs(), divOp.getRhs());
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::DivSI, divOp.getLhs(), divOp.getRhs());
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::MinUI, minOp.getLhs(), minOp.getRhs());
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemUIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemUI, remOp.getLhs(), remOp.getRhs());
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp))
|
||||
return buildBinaryExpr(CompiledIndexExprNode::Kind::RemSI, remOp.getLhs(), remOp.getRhs());
|
||||
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
|
||||
auto expr = buildBinaryExpr(CompiledIndexExprNode::Kind::CmpI, cmpOp.getLhs(), cmpOp.getRhs());
|
||||
if (failed(expr))
|
||||
return mlir::failure();
|
||||
auto exprNode = std::make_shared<CompiledIndexExprNode>(*expr->node);
|
||||
exprNode->predicate = cmpOp.getPredicate();
|
||||
return CompiledIndexExpr(exprNode);
|
||||
}
|
||||
if (auto maxOp = mlir::dyn_cast<mlir::arith::MaxUIOp>(definingOp)) {
|
||||
auto lhs = compileIndexValueImpl(maxOp.getLhs());
|
||||
auto rhs = compileIndexValueImpl(maxOp.getRhs());
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
|
||||
CompiledIndexExprNode cmpExpr;
|
||||
cmpExpr.kind = CompiledIndexExprNode::Kind::CmpI;
|
||||
cmpExpr.predicate = mlir::arith::CmpIPredicate::uge;
|
||||
cmpExpr.operands = {*lhs, *rhs};
|
||||
|
||||
CompiledIndexExprNode selectExpr;
|
||||
selectExpr.kind = CompiledIndexExprNode::Kind::Select;
|
||||
selectExpr.operands = {makeCompiledIndexExpr(std::move(cmpExpr)), *lhs, *rhs};
|
||||
return makeCompiledIndexExpr(std::move(selectExpr));
|
||||
}
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
|
||||
auto condition = compileIndexValueImpl(selectOp.getCondition());
|
||||
auto trueValue = compileIndexValueImpl(selectOp.getTrueValue());
|
||||
auto falseValue = compileIndexValueImpl(selectOp.getFalseValue());
|
||||
if (failed(condition) || failed(trueValue) || failed(falseValue))
|
||||
return mlir::failure();
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Select;
|
||||
expr.operands = {*condition, *trueValue, *falseValue};
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return compileConstantGlobalLoad(loadOp);
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Symbol;
|
||||
expr.symbol = value;
|
||||
return makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticValueKnowledge* knowledge) {
|
||||
value = resolveAlias(value, knowledge);
|
||||
@@ -110,6 +404,16 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto divOp = mlir::dyn_cast<mlir::arith::DivSIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return mlir::failure();
|
||||
return *lhs / *rhs;
|
||||
}
|
||||
|
||||
if (auto minOp = mlir::dyn_cast<mlir::arith::MinUIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(minOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(minOp.getRhs(), knowledge);
|
||||
@@ -126,6 +430,34 @@ llvm::FailureOr<int64_t> resolveIndexValueImpl(mlir::Value value, const StaticVa
|
||||
return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
|
||||
}
|
||||
|
||||
if (auto remOp = mlir::dyn_cast<mlir::arith::RemSIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs) || *rhs == 0)
|
||||
return mlir::failure();
|
||||
if (*lhs == std::numeric_limits<int64_t>::min() && *rhs == -1)
|
||||
return 0;
|
||||
return *lhs % *rhs;
|
||||
}
|
||||
|
||||
if (auto cmpOp = mlir::dyn_cast<mlir::arith::CmpIOp>(definingOp)) {
|
||||
auto lhs = resolveIndexValueImpl(cmpOp.getLhs(), knowledge);
|
||||
auto rhs = resolveIndexValueImpl(cmpOp.getRhs(), knowledge);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return mlir::failure();
|
||||
return evaluateCmpPredicate(cmpOp.getPredicate(), *lhs, *rhs) ? 1 : 0;
|
||||
}
|
||||
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(definingOp)) {
|
||||
auto condition = resolveIndexValueImpl(selectOp.getCondition(), knowledge);
|
||||
if (failed(condition))
|
||||
return mlir::failure();
|
||||
return resolveIndexValueImpl(*condition != 0 ? selectOp.getTrueValue() : selectOp.getFalseValue(), knowledge);
|
||||
}
|
||||
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(definingOp))
|
||||
return resolveConstantGlobalLoad(loadOp, knowledge);
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
@@ -217,8 +549,10 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
|
||||
byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8;
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
byteOffset += linearizeIndex(offsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
value = resolveAlias(subviewOp.getSource(), knowledge);
|
||||
continue;
|
||||
}
|
||||
@@ -243,17 +577,206 @@ llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(mlir::Va
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExprImpl(mlir::Value value) {
|
||||
int64_t constantByteOffset = 0;
|
||||
CompiledIndexExpr byteOffsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = 0;
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value) { return resolveIndexValueImpl(value, nullptr); }
|
||||
while (true) {
|
||||
if (mlir::isa<mlir::BlockArgument>(value))
|
||||
return CompiledAddressExpr {value, byteOffsetExpr};
|
||||
|
||||
mlir::Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto dpsDefiningOp = mlir::dyn_cast<mlir::DestinationStyleOpInterface>(definingOp)) {
|
||||
mlir::OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(mlir::dyn_cast<mlir::OpResult>(value));
|
||||
if (!tiedOperand)
|
||||
return mlir::failure();
|
||||
value = tiedOperand->get();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(definingOp)) {
|
||||
auto result = mlir::dyn_cast<mlir::OpResult>(value);
|
||||
if (!result)
|
||||
return mlir::failure();
|
||||
|
||||
auto yieldOp = mlir::cast<mlir::scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
mlir::Value yieldedValue = yieldOp.getOperand(result.getResultNumber());
|
||||
if (auto blockArgument = mlir::dyn_cast<mlir::BlockArgument>(yieldedValue)) {
|
||||
if (blockArgument.getOwner() == forOp.getBody() && blockArgument.getArgNumber() > 0
|
||||
&& static_cast<unsigned>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
|
||||
value = forOp.getInitArgs()[blockArgument.getArgNumber() - 1];
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
value = yieldedValue;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto subviewOp = mlir::dyn_cast<mlir::memref::SubViewOp>(definingOp)) {
|
||||
auto sourceType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getSource().getType());
|
||||
auto subviewType = mlir::dyn_cast<mlir::MemRefType>(subviewOp.getType());
|
||||
if (!sourceType || !subviewType || !sourceType.hasStaticShape() || !subviewType.hasStaticShape())
|
||||
return mlir::failure();
|
||||
|
||||
llvm::SmallVector<int64_t> staticSizes;
|
||||
staticSizes.reserve(subviewOp.getMixedSizes().size());
|
||||
llvm::SmallVector<int64_t> staticStrides;
|
||||
staticStrides.reserve(subviewOp.getMixedStrides().size());
|
||||
llvm::SmallVector<int64_t> staticOffsets;
|
||||
staticOffsets.reserve(subviewOp.getMixedOffsets().size());
|
||||
bool hasOnlyStaticOffsets = true;
|
||||
|
||||
for (mlir::OpFoldResult offset : subviewOp.getMixedOffsets())
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
staticOffsets.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
else
|
||||
hasOnlyStaticOffsets = false;
|
||||
for (mlir::OpFoldResult size : subviewOp.getMixedSizes()) {
|
||||
auto attr = mlir::dyn_cast<mlir::Attribute>(size);
|
||||
if (!attr)
|
||||
return mlir::failure();
|
||||
staticSizes.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
}
|
||||
for (mlir::OpFoldResult stride : subviewOp.getMixedStrides()) {
|
||||
auto attr = mlir::dyn_cast<mlir::Attribute>(stride);
|
||||
if (!attr)
|
||||
return mlir::failure();
|
||||
staticStrides.push_back(mlir::cast<mlir::IntegerAttr>(attr).getInt());
|
||||
}
|
||||
|
||||
if (!isContiguousSubviewWithDynamicOffsets(
|
||||
sourceType.getShape(), subviewOp.getMixedOffsets(), staticSizes, staticStrides)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (hasOnlyStaticOffsets) {
|
||||
if (!isMemoryContiguous(sourceType.getShape(), staticOffsets, staticSizes, staticStrides))
|
||||
return mlir::failure();
|
||||
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
constantByteOffset +=
|
||||
linearizeIndex(staticOffsets, *sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
}
|
||||
else {
|
||||
auto sourceStrides = getStaticMemRefStrides(sourceType);
|
||||
if (failed(sourceStrides))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr offsetExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = 0;
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
for (auto [mixedOffset, sourceStride] : llvm::zip_equal(subviewOp.getMixedOffsets(), *sourceStrides)) {
|
||||
CompiledIndexExpr operandExpr;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(mixedOffset)) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = mlir::cast<mlir::IntegerAttr>(attr).getInt() * sourceStride
|
||||
* getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
operandExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
else {
|
||||
auto compiledOffset = compileIndexValueImpl(mlir::cast<mlir::Value>(mixedOffset));
|
||||
if (failed(compiledOffset))
|
||||
return mlir::failure();
|
||||
CompiledIndexExpr scaleExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = sourceStride * getElementTypeSizeInBytes(subviewType.getElementType());
|
||||
scaleExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Mul;
|
||||
expr.operands = {*compiledOffset, scaleExpr};
|
||||
operandExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {offsetExpr, operandExpr};
|
||||
offsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
|
||||
CompiledIndexExpr constantExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constantByteOffset;
|
||||
constantExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {constantExpr, offsetExpr};
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
constantByteOffset = 0;
|
||||
}
|
||||
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(definingOp)) {
|
||||
value = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapseOp = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(definingOp)) {
|
||||
value = collapseOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expandOp = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(definingOp)) {
|
||||
value = expandOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::AllocOp, mlir::memref::GetGlobalOp>(definingOp)) {
|
||||
if (constantByteOffset != 0) {
|
||||
CompiledIndexExpr constantExpr;
|
||||
{
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constantByteOffset;
|
||||
constantExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
if (byteOffsetExpr.node->kind == CompiledIndexExprNode::Kind::Constant && byteOffsetExpr.node->constant == 0)
|
||||
byteOffsetExpr = constantExpr;
|
||||
else {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Add;
|
||||
expr.operands = {constantExpr, byteOffsetExpr};
|
||||
byteOffsetExpr = makeCompiledIndexExpr(std::move(expr));
|
||||
}
|
||||
}
|
||||
return CompiledAddressExpr {value, byteOffsetExpr};
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge) {
|
||||
return resolveIndexValueImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value) {
|
||||
return resolveContiguousAddressImpl(value, nullptr);
|
||||
}
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value) { return compileIndexValueImpl(value); }
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge) {
|
||||
@@ -264,4 +787,21 @@ mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledg
|
||||
return resolveLoopCarriedAliasImpl(value, &knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value) {
|
||||
return compileContiguousAddressExprImpl(value);
|
||||
}
|
||||
|
||||
llvm::FailureOr<int64_t> CompiledIndexExpr::evaluate(const StaticValueKnowledge& knowledge) const {
|
||||
return evaluateCompiledIndexExpr(*this, knowledge);
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> CompiledAddressExpr::evaluate(const StaticValueKnowledge& knowledge,
|
||||
std::optional<unsigned> lane) const {
|
||||
(void) lane;
|
||||
auto resolvedOffset = byteOffset.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
return ResolvedContiguousAddress {base, *resolvedOffset};
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/// Describes a value as a base addressable object plus a statically known
|
||||
@@ -23,21 +27,68 @@ struct StaticValueKnowledge {
|
||||
StaticValueKnowledge() {}
|
||||
};
|
||||
|
||||
struct CompiledIndexExprNode;
|
||||
|
||||
struct CompiledIndexExpr {
|
||||
std::shared_ptr<CompiledIndexExprNode> node;
|
||||
|
||||
CompiledIndexExpr() = default;
|
||||
explicit CompiledIndexExpr(std::shared_ptr<CompiledIndexExprNode> node)
|
||||
: node(std::move(node)) {}
|
||||
|
||||
llvm::FailureOr<int64_t> evaluate(const StaticValueKnowledge& knowledge) const;
|
||||
};
|
||||
|
||||
struct CompiledIndexExprNode {
|
||||
enum class Kind {
|
||||
Constant,
|
||||
Symbol,
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
DivUI,
|
||||
DivSI,
|
||||
RemUI,
|
||||
RemSI,
|
||||
MinUI,
|
||||
CmpI,
|
||||
Select,
|
||||
ConstantGlobalLoad
|
||||
};
|
||||
|
||||
Kind kind = Kind::Constant;
|
||||
int64_t constant = 0;
|
||||
mlir::Value symbol;
|
||||
mlir::arith::CmpIPredicate predicate = mlir::arith::CmpIPredicate::eq;
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t, 4> globalStrides;
|
||||
llvm::SmallVector<CompiledIndexExpr, 4> operands;
|
||||
};
|
||||
|
||||
struct CompiledAddressExpr {
|
||||
mlir::Value base;
|
||||
CompiledIndexExpr byteOffset;
|
||||
|
||||
llvm::FailureOr<ResolvedContiguousAddress> evaluate(const StaticValueKnowledge& knowledge,
|
||||
std::optional<unsigned> lane) const;
|
||||
};
|
||||
|
||||
mlir::memref::GlobalOp lookupGlobalForGetGlobal(mlir::ModuleOp moduleOp, mlir::memref::GetGlobalOp getGlobalOp);
|
||||
|
||||
/// Resolves a value to contiguous backing storage when that storage can be
|
||||
/// proven statically from aliases, DPS ties, casts, and subviews.
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value);
|
||||
llvm::FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge);
|
||||
const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
/// Statically evaluates index-like SSA values, including simple integer
|
||||
/// arithmetic and loop facts recorded in `knowledge`.
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
llvm::FailureOr<int64_t> resolveIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {});
|
||||
llvm::FailureOr<CompiledIndexExpr> compileIndexExpr(mlir::Value value);
|
||||
|
||||
/// Follows alias, view, and DPS chains to recover the backing value of a
|
||||
/// loop-carried memref/result.
|
||||
mlir::Value resolveLoopCarriedAlias(mlir::Value value, const StaticValueKnowledge& knowledge);
|
||||
|
||||
llvm::FailureOr<CompiledAddressExpr> compileContiguousAddressExpr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "AffineUtils.hpp"
|
||||
#include "ConstantUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
static FailureOr<int64_t> floorDivSigned(int64_t lhs, int64_t rhs) {
|
||||
if (rhs <= 0)
|
||||
return failure();
|
||||
|
||||
int64_t quotient = lhs / rhs;
|
||||
int64_t remainder = lhs % rhs;
|
||||
if (remainder != 0 && lhs < 0)
|
||||
--quotient;
|
||||
return quotient;
|
||||
}
|
||||
|
||||
static FailureOr<int64_t> ceilDivSigned(int64_t lhs, int64_t rhs) {
|
||||
if (rhs <= 0)
|
||||
return failure();
|
||||
|
||||
int64_t quotient = lhs / rhs;
|
||||
int64_t remainder = lhs % rhs;
|
||||
if (remainder != 0 && lhs > 0)
|
||||
++quotient;
|
||||
return quotient;
|
||||
}
|
||||
|
||||
Value createOrFoldAffineApply(
|
||||
RewriterBase& rewriter, Location loc, AffineMap map, ValueRange operands, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(map.getNumResults() == 1 && "affine.apply expects a single-result affine map");
|
||||
|
||||
SmallVector<Attribute> operandConstants;
|
||||
operandConstants.reserve(operands.size());
|
||||
for (Value operand : operands) {
|
||||
std::optional<int64_t> constantValue = matchConstantIndexValue(operand);
|
||||
if (!constantValue)
|
||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||
operandConstants.push_back(rewriter.getIndexAttr(*constantValue));
|
||||
}
|
||||
|
||||
SmallVector<Attribute> foldedResults;
|
||||
if (succeeded(map.constantFold(operandConstants, foldedResults)) && foldedResults.size() == 1)
|
||||
if (auto constantResult = dyn_cast<IntegerAttr>(foldedResults.front()))
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, constantResult.getInt());
|
||||
|
||||
return affine::AffineApplyOp::create(rewriter, loc, map, operands).getResult();
|
||||
}
|
||||
|
||||
Value createOrFoldAffineApply(
|
||||
RewriterBase& rewriter, Location loc, AffineExpr expr, ValueRange dims, Operation* constantAnchor) {
|
||||
AffineMap map = AffineMap::get(/*dimCount=*/dims.size(), /*symbolCount=*/0, expr);
|
||||
return createOrFoldAffineApply(rewriter, loc, map, dims, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineMulConst(RewriterBase& rewriter, Location loc, Value value, int64_t multiplier, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
if (multiplier == 0)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
if (multiplier == 1)
|
||||
return value;
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 * multiplier, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineModConst(RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.mod divisor");
|
||||
if (divisor == 1)
|
||||
return getOrCreateIndexConstant(rewriter, constantAnchor, 0);
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0 % divisor, ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
Value affineFloorDivConst(
|
||||
RewriterBase& rewriter, Location loc, Value value, int64_t divisor, Operation* constantAnchor) {
|
||||
assert(constantAnchor && "expected a valid constant anchor");
|
||||
assert(divisor > 0 && "expected a positive affine.floor_div divisor");
|
||||
if (divisor == 1)
|
||||
return value;
|
||||
|
||||
AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
|
||||
return createOrFoldAffineApply(rewriter, loc, d0.floorDiv(divisor), ValueRange {value}, constantAnchor);
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateAffineExpr(AffineExpr expr, ArrayRef<int64_t> dims, ArrayRef<int64_t> symbols) {
|
||||
if (auto constant = dyn_cast<AffineConstantExpr>(expr))
|
||||
return constant.getValue();
|
||||
if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
|
||||
unsigned position = dim.getPosition();
|
||||
if (position >= dims.size())
|
||||
return failure();
|
||||
return dims[position];
|
||||
}
|
||||
if (auto symbol = dyn_cast<AffineSymbolExpr>(expr)) {
|
||||
unsigned position = symbol.getPosition();
|
||||
if (position >= symbols.size())
|
||||
return failure();
|
||||
return symbols[position];
|
||||
}
|
||||
|
||||
auto binary = dyn_cast<AffineBinaryOpExpr>(expr);
|
||||
if (!binary)
|
||||
return failure();
|
||||
|
||||
FailureOr<int64_t> lhs = evaluateAffineExpr(binary.getLHS(), dims, symbols);
|
||||
FailureOr<int64_t> rhs = evaluateAffineExpr(binary.getRHS(), dims, symbols);
|
||||
if (failed(lhs) || failed(rhs))
|
||||
return failure();
|
||||
|
||||
switch (binary.getKind()) {
|
||||
case AffineExprKind::Add: return *lhs + *rhs;
|
||||
case AffineExprKind::Mul: return *lhs * *rhs;
|
||||
case AffineExprKind::FloorDiv: return floorDivSigned(*lhs, *rhs);
|
||||
case AffineExprKind::CeilDiv: return ceilDivSigned(*lhs, *rhs);
|
||||
case AffineExprKind::Mod: {
|
||||
FailureOr<int64_t> div = floorDivSigned(*lhs, *rhs);
|
||||
if (failed(div))
|
||||
return failure();
|
||||
return *lhs - *div * *rhs;
|
||||
}
|
||||
default: return failure();
|
||||
}
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateSingleResultAffineMap(AffineMap map, ArrayRef<int64_t> operands) {
|
||||
if (map.getNumResults() != 1 || operands.size() != map.getNumInputs())
|
||||
return failure();
|
||||
|
||||
ArrayRef<int64_t> dims(operands.data(), map.getNumDims());
|
||||
ArrayRef<int64_t> symbols(operands.data() + map.getNumDims(), map.getNumSymbols());
|
||||
return evaluateAffineExpr(map.getResult(0), dims, symbols);
|
||||
}
|
||||
|
||||
FailureOr<int64_t> evaluateAffineApply(affine::AffineApplyOp affineApply, IndexValueResolver resolver) {
|
||||
SmallVector<int64_t, 4> operands;
|
||||
operands.reserve(affineApply.getMapOperands().size());
|
||||
for (Value operand : affineApply.getMapOperands()) {
|
||||
FailureOr<int64_t> folded = resolver(operand);
|
||||
if (failed(folded))
|
||||
return failure();
|
||||
operands.push_back(*folded);
|
||||
}
|
||||
|
||||
return evaluateSingleResultAffineMap(affineApply.getAffineMap(), operands);
|
||||
}
|
||||
|
||||
bool isSingleResultSymbolFreeAffineMap(AffineMap map) { return map.getNumResults() == 1 && map.getNumSymbols() == 0; }
|
||||
|
||||
bool isDimAndConstantAffineExpr(AffineExpr expr) {
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Constant:
|
||||
case AffineExprKind::DimId: return true;
|
||||
case AffineExprKind::SymbolId: return false;
|
||||
case AffineExprKind::Add: {
|
||||
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return isDimAndConstantAffineExpr(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS());
|
||||
}
|
||||
case AffineExprKind::Mul: {
|
||||
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return (isa<AffineConstantExpr>(binaryExpr.getLHS()) && isDimAndConstantAffineExpr(binaryExpr.getRHS()))
|
||||
|| (isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS()));
|
||||
}
|
||||
case AffineExprKind::FloorDiv:
|
||||
case AffineExprKind::CeilDiv:
|
||||
case AffineExprKind::Mod: {
|
||||
auto binaryExpr = cast<AffineBinaryOpExpr>(expr);
|
||||
return isa<AffineConstantExpr>(binaryExpr.getRHS()) && isDimAndConstantAffineExpr(binaryExpr.getLHS());
|
||||
}
|
||||
}
|
||||
|
||||
llvm_unreachable("unexpected affine expression kind");
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/FunctionExtras.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
using IndexValueResolver = llvm::function_ref<llvm::FailureOr<int64_t>(mlir::Value)>;
|
||||
|
||||
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineMap map,
|
||||
mlir::ValueRange operands,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value createOrFoldAffineApply(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::AffineExpr expr,
|
||||
mlir::ValueRange dims,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineMulConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t multiplier,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineModConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
mlir::Value affineFloorDivConst(mlir::RewriterBase& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value value,
|
||||
int64_t divisor,
|
||||
mlir::Operation* constantAnchor);
|
||||
|
||||
llvm::FailureOr<int64_t>
|
||||
evaluateAffineExpr(mlir::AffineExpr expr, llvm::ArrayRef<int64_t> dims, llvm::ArrayRef<int64_t> symbols = {});
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateSingleResultAffineMap(mlir::AffineMap map, llvm::ArrayRef<int64_t> operands);
|
||||
|
||||
llvm::FailureOr<int64_t> evaluateAffineApply(mlir::affine::AffineApplyOp affineApply, IndexValueResolver resolver);
|
||||
|
||||
bool isSingleResultSymbolFreeAffineMap(mlir::AffineMap map);
|
||||
|
||||
bool isDimAndConstantAffineExpr(mlir::AffineExpr expr);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,32 @@
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<mlir::DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||
return llvm::SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
llvm::SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
if (mlir::isa<pim::PimMemCopyHostToDevOp>(op))
|
||||
return operandIndex == 3;
|
||||
if (mlir::isa<pim::PimMemCopyDevToHostOp>(op))
|
||||
return operandIndex == 2;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex) {
|
||||
return mlir::isa<pim::PimMemCopyDevToHostOp>(op) && operandIndex == 2;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp);
|
||||
|
||||
llvm::SmallVector<int32_t> getLaneChunkCoreIds(llvm::ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane);
|
||||
|
||||
bool isExplicitHostMemCopyOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
bool isExplicitDevToHostTargetOperand(mlir::Operation* op, unsigned operandIndex);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,93 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "ConstantUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
Block* getConstantInsertionBlock(Operation* anchorOp) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
|
||||
if (auto funcOp = dyn_cast<func::FuncOp>(anchorOp))
|
||||
return &funcOp.getBody().front();
|
||||
if (auto funcOp = anchorOp->getParentOfType<func::FuncOp>())
|
||||
return &funcOp.getBody().front();
|
||||
if (auto moduleOp = dyn_cast<ModuleOp>(anchorOp))
|
||||
return moduleOp.getBody();
|
||||
if (auto moduleOp = anchorOp->getParentOfType<ModuleOp>())
|
||||
return moduleOp.getBody();
|
||||
return anchorOp->getBlock();
|
||||
}
|
||||
|
||||
Value getOrCreateConstant(OperationFolder& folder, Operation* anchorOp, Attribute value, Type type) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getConstantInsertionBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
auto* arithDialect = anchorOp->getContext()->getOrLoadDialect<arith::ArithDialect>();
|
||||
return folder.getOrCreateConstant(hostBlock, arithDialect, value, type);
|
||||
}
|
||||
|
||||
Value getOrCreateConstant(RewriterBase& rewriter, Operation* anchorOp, Attribute value, Type type) {
|
||||
assert(anchorOp && "expected a valid anchor operation");
|
||||
Block* hostBlock = getConstantInsertionBlock(anchorOp);
|
||||
for (Operation& op : *hostBlock) {
|
||||
auto constantOp = dyn_cast<arith::ConstantOp>(&op);
|
||||
if (!constantOp || constantOp.getType() != type || constantOp.getValue() != value)
|
||||
continue;
|
||||
return constantOp.getResult();
|
||||
}
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(hostBlock);
|
||||
return arith::ConstantOp::create(rewriter, anchorOp->getLoc(), type, cast<TypedAttr>(value)).getResult();
|
||||
}
|
||||
|
||||
Value getOrCreateConstantLike(OperationFolder& folder, arith::ConstantOp constantOp) {
|
||||
return getOrCreateConstant(folder, constantOp.getOperation(), constantOp.getValue(), constantOp.getType());
|
||||
}
|
||||
|
||||
Value getOrCreateIndexConstant(OperationFolder& folder, Operation* anchorOp, int64_t value) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateConstant(folder, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
Value getOrCreateIndexConstant(RewriterBase& rewriter, Operation* anchorOp, int64_t value) {
|
||||
Builder builder(anchorOp->getContext());
|
||||
return getOrCreateConstant(rewriter, anchorOp, builder.getIndexAttr(value), builder.getIndexType());
|
||||
}
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(Value value) {
|
||||
if (!value || !value.getType().isIndex())
|
||||
return std::nullopt;
|
||||
|
||||
if (auto constant = value.getDefiningOp<arith::ConstantIndexOp>())
|
||||
return constant.value();
|
||||
|
||||
if (auto constant = value.getDefiningOp<arith::ConstantOp>())
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(constant.getValue()); intAttr && intAttr.getType().isIndex())
|
||||
return intAttr.getInt();
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(OpFoldResult value) {
|
||||
if (auto attr = dyn_cast<Attribute>(value))
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(attr); intAttr && intAttr.getType().isIndex())
|
||||
return intAttr.getInt();
|
||||
if (auto operand = dyn_cast<Value>(value))
|
||||
return matchConstantIndexValue(operand);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,30 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::Block* getConstantInsertionBlock(mlir::Operation* anchorOp);
|
||||
|
||||
mlir::Value
|
||||
getOrCreateConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
|
||||
|
||||
mlir::Value
|
||||
getOrCreateConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, mlir::Attribute value, mlir::Type type);
|
||||
|
||||
mlir::Value getOrCreateConstantLike(mlir::OperationFolder& folder, mlir::arith::ConstantOp constantOp);
|
||||
|
||||
mlir::Value getOrCreateIndexConstant(mlir::OperationFolder& folder, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
mlir::Value getOrCreateIndexConstant(mlir::RewriterBase& rewriter, mlir::Operation* anchorOp, int64_t value);
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(mlir::Value value);
|
||||
|
||||
std::optional<int64_t> matchConstantIndexValue(mlir::OpFoldResult value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,25 +1,37 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isCoreStaticAddressOp(mlir::Operation* op) {
|
||||
return mlir::isa<mlir::arith::ConstantOp,
|
||||
if (mlir::isa<mlir::arith::ConstantOp,
|
||||
mlir::arith::AddIOp,
|
||||
mlir::arith::SubIOp,
|
||||
mlir::arith::MulIOp,
|
||||
mlir::arith::DivUIOp,
|
||||
mlir::arith::DivSIOp,
|
||||
mlir::arith::MinUIOp,
|
||||
mlir::arith::RemUIOp,
|
||||
mlir::arith::RemSIOp,
|
||||
mlir::arith::IndexCastOp,
|
||||
mlir::arith::CmpIOp,
|
||||
mlir::memref::AllocOp,
|
||||
mlir::memref::SubViewOp,
|
||||
mlir::memref::CastOp,
|
||||
mlir::memref::CollapseShapeOp,
|
||||
mlir::memref::ExpandShapeOp>(op);
|
||||
mlir::memref::ExpandShapeOp>(op))
|
||||
return true;
|
||||
|
||||
if (auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(op))
|
||||
return selectOp.getType().isIntOrIndex();
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
mlir::LogicalResult
|
||||
@@ -30,6 +42,9 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
@@ -65,4 +80,58 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
return mlir::success(!hasFailure);
|
||||
}
|
||||
|
||||
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||
mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback) {
|
||||
bool hasFailure = false;
|
||||
for (mlir::Operation& op : block) {
|
||||
if (mlir::isa<pim::PimHaltOp, mlir::scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
|
||||
continue;
|
||||
if (auto loadOp = mlir::dyn_cast<mlir::memref::LoadOp>(op);
|
||||
loadOp && succeeded(resolveIndexValue(loadOp.getResult(), knowledge)))
|
||||
continue;
|
||||
|
||||
if (auto forOp = mlir::dyn_cast<mlir::scf::ForOp>(op)) {
|
||||
mlir::Block& loopBody = forOp.getRegion().front();
|
||||
auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
|
||||
auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
|
||||
auto step = resolveIndexValue(forOp.getStep(), knowledge);
|
||||
if (failed(lowerBound) || failed(upperBound) || failed(step)) {
|
||||
forOp.emitOpError("requires statically evaluable scf.for bounds for PIM verification");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
if (*step <= 0) {
|
||||
forOp.emitOpError("requires positive scf.for step for PIM verification");
|
||||
hasFailure = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 2> samples;
|
||||
if (*lowerBound < *upperBound) {
|
||||
samples.push_back(*lowerBound);
|
||||
int64_t last = *lowerBound + ((*upperBound - 1 - *lowerBound) / *step) * *step;
|
||||
if (last != *lowerBound)
|
||||
samples.push_back(last);
|
||||
}
|
||||
|
||||
for (int64_t inductionValue : samples) {
|
||||
StaticValueKnowledge loopKnowledge = knowledge;
|
||||
loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
|
||||
for (auto [iterArg, iterValue] : llvm::zip_equal(forOp.getRegionIterArgs(), forOp.getInitArgs()))
|
||||
loopKnowledge.aliases[iterArg] = iterValue;
|
||||
|
||||
if (failed(walkPimCoreBlockStructurally(loopBody, loopKnowledge, callback)))
|
||||
hasFailure = true;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (failed(callback(op, knowledge)))
|
||||
hasFailure = true;
|
||||
}
|
||||
return mlir::success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -21,4 +21,12 @@ walkPimCoreBlock(mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
/// Walks a `pim.core`-like body structurally for verification without
|
||||
/// enumerating full loop trip counts. Loop bounds must still be statically
|
||||
/// evaluable so address resolution remains well-defined.
|
||||
mlir::LogicalResult walkPimCoreBlockStructurally(
|
||||
mlir::Block& block,
|
||||
const StaticValueKnowledge& knowledge,
|
||||
llvm::function_ref<mlir::LogicalResult(mlir::Operation&, const StaticValueKnowledge&)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
@@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef<int64_t> shape) {
|
||||
return numElements;
|
||||
}
|
||||
|
||||
bool hasByteSizedElementType(mlir::Type elementType) {
|
||||
if (mlir::isa<mlir::IndexType>(elementType))
|
||||
return true;
|
||||
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||
return intType.getWidth() > 0 && intType.getWidth() % 8 == 0;
|
||||
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||
return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0;
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t getElementTypeSizeInBytes(mlir::Type elementType) {
|
||||
if (mlir::isa<mlir::IndexType>(elementType))
|
||||
return mlir::IndexType::kInternalStorageBitWidth / 8;
|
||||
if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType))
|
||||
return static_cast<size_t>(intType.getWidth() / 8);
|
||||
if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType))
|
||||
return static_cast<size_t>(floatType.getWidth() / 8);
|
||||
llvm_unreachable("expected byte-sized integer, float, or index element type");
|
||||
}
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) {
|
||||
return static_cast<size_t>(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType());
|
||||
}
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
@@ -86,4 +111,56 @@ bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides) {
|
||||
if (sourceShape.size() != mixedOffsets.size() || sourceShape.size() != staticSizes.size()
|
||||
|| sourceShape.size() != staticStrides.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llvm::any_of(staticStrides, [](int64_t stride) { return stride != 1; }))
|
||||
return false;
|
||||
|
||||
auto reversedTriples =
|
||||
llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(mixedOffsets), llvm::reverse(staticSizes));
|
||||
|
||||
auto firstNonZeroOrDynamicOffset = llvm::find_if(reversedTriples, [](auto triple) {
|
||||
auto [_sourceDim, offset, _size] = triple;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset))
|
||||
return mlir::cast<mlir::IntegerAttr>(attr).getInt() != 0;
|
||||
return true;
|
||||
});
|
||||
|
||||
if (firstNonZeroOrDynamicOffset != reversedTriples.end()) {
|
||||
auto [sourceDim, offset, size] = *firstNonZeroOrDynamicOffset;
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||
int64_t staticOffset = mlir::cast<mlir::IntegerAttr>(attr).getInt();
|
||||
if (size > sourceDim - staticOffset)
|
||||
return false;
|
||||
}
|
||||
|
||||
++firstNonZeroOrDynamicOffset;
|
||||
for (auto it = firstNonZeroOrDynamicOffset; it != reversedTriples.end(); ++it)
|
||||
if (std::get<2>(*it) != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
auto reversedSizes = llvm::zip_equal(llvm::reverse(sourceShape), llvm::reverse(staticSizes));
|
||||
auto firstDifferentSize = llvm::find_if(reversedSizes, [](auto pair) {
|
||||
auto [sourceDim, size] = pair;
|
||||
return size != sourceDim;
|
||||
});
|
||||
|
||||
if (firstDifferentSize != reversedSizes.end()) {
|
||||
++firstDifferentSize;
|
||||
for (auto it = firstDifferentSize; it != reversedSizes.end(); ++it)
|
||||
if (std::get<1>(*it) != 1)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::SmallVector<int64_t> computeRowMajorStrides(llvm::ArrayRef<int64_t> shape);
|
||||
@@ -14,9 +20,20 @@ int64_t linearizeIndex(llvm::ArrayRef<int64_t> indices, llvm::ArrayRef<int64_t>
|
||||
|
||||
int64_t getNumElements(llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasByteSizedElementType(mlir::Type elementType);
|
||||
|
||||
size_t getElementTypeSizeInBytes(mlir::Type elementType);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
bool isMemoryContiguous(llvm::ArrayRef<int64_t> srcShape,
|
||||
llvm::ArrayRef<int64_t> offsets,
|
||||
llvm::ArrayRef<int64_t> sizes,
|
||||
llvm::ArrayRef<int64_t> strides);
|
||||
|
||||
bool isContiguousSubviewWithDynamicOffsets(llvm::ArrayRef<int64_t> sourceShape,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> mixedOffsets,
|
||||
llvm::ArrayRef<int64_t> staticSizes,
|
||||
llvm::ArrayRef<int64_t> staticStrides);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -31,6 +31,19 @@ Value stripMemRefViewOps(Value value) {
|
||||
}
|
||||
}
|
||||
|
||||
Value stripMemRefAddressingOps(Value value) {
|
||||
while (true) {
|
||||
if (auto subviewOp = value.getDefiningOp<memref::SubViewOp>()) {
|
||||
value = subviewOp.getSource();
|
||||
continue;
|
||||
}
|
||||
Value strippedValue = stripMemRefViewOps(value);
|
||||
if (strippedValue == value)
|
||||
return value;
|
||||
value = strippedValue;
|
||||
}
|
||||
}
|
||||
|
||||
bool hasAllStaticSubviewParts(memref::SubViewOp subview) {
|
||||
return llvm::all_of(subview.getStaticOffsets(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
&& llvm::all_of(subview.getStaticSizes(), [](int64_t value) { return !ShapedType::isDynamic(value); })
|
||||
@@ -81,4 +94,13 @@ FailureOr<SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo&
|
||||
return staticOffsets;
|
||||
}
|
||||
|
||||
bool isMemRefBaseAddressableValue(Value value) {
|
||||
value = stripMemRefAddressingOps(value);
|
||||
if (isa<BlockArgument>(value))
|
||||
return true;
|
||||
|
||||
Operation* defOp = value.getDefiningOp();
|
||||
return defOp && isa<memref::AllocOp, memref::GetGlobalOp>(defOp);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -20,6 +20,8 @@ mlir::Value stripMemRefCasts(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefViewOps(mlir::Value value);
|
||||
|
||||
mlir::Value stripMemRefAddressingOps(mlir::Value value);
|
||||
|
||||
bool hasAllStaticSubviewParts(mlir::memref::SubViewOp subview);
|
||||
|
||||
llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
@@ -27,4 +29,6 @@ llvm::FailureOr<StaticSubviewInfo> getStaticSubviewInfo(mlir::Value value);
|
||||
/// Returns the offsets in `info` as int64_t, failing if any offset is dynamic.
|
||||
llvm::FailureOr<llvm::SmallVector<int64_t>> getStaticSubviewOffsets(const StaticSubviewInfo& info);
|
||||
|
||||
bool isMemRefBaseAddressableValue(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
@@ -19,29 +25,57 @@ void markWeightAlways(mlir::Operation* op) {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
CompiledIndexExpr makeConstantExpr(int64_t constant) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = CompiledIndexExprNode::Kind::Constant;
|
||||
expr.constant = constant;
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
|
||||
}
|
||||
|
||||
CompiledIndexExpr makeBinaryExpr(CompiledIndexExprNode::Kind kind, CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
|
||||
CompiledIndexExprNode expr;
|
||||
expr.kind = kind;
|
||||
expr.operands = {std::move(lhs), std::move(rhs)};
|
||||
return CompiledIndexExpr(std::make_shared<CompiledIndexExprNode>(std::move(expr)));
|
||||
}
|
||||
|
||||
CompiledIndexExpr addExpr(CompiledIndexExpr lhs, CompiledIndexExpr rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Add, std::move(lhs), std::move(rhs));
|
||||
}
|
||||
|
||||
CompiledIndexExpr mulExpr(CompiledIndexExpr lhs, int64_t rhs) {
|
||||
return makeBinaryExpr(CompiledIndexExprNode::Kind::Mul, std::move(lhs), makeConstantExpr(rhs));
|
||||
}
|
||||
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeightIndex() == weightIndex;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeightIndex() == weightIndex;
|
||||
if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeight() == *weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
template <typename VMMOpTy, typename ParentOpTy>
|
||||
void walkVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpOperand&)> callback) {
|
||||
auto weights = parentOp.getWeights();
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeightIndex = [&](unsigned weightIndex) {
|
||||
if (weightIndex < weights.size() && visited.insert(weightIndex).second)
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg || *weightArg != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
|
||||
parentOp.walk([&](VMMOpTy op) { walkWeight(op.getWeight()); });
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@@ -54,7 +88,7 @@ bool isSpatialMvmVmmWeightUse(mlir::OpOperand& use) {
|
||||
if (!computeOp || operandIndex >= computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
return hasMvmVmmWeightUse<spatial::SpatMVMOp, spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||
return hasVmmWeightUse<spatial::SpatVMMOp>(computeOp, operandIndex);
|
||||
}
|
||||
|
||||
bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||
@@ -76,8 +110,8 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value) {
|
||||
return expandShapeOp.getSrc() == currentValue && self(expandShapeOp.getResult(), self);
|
||||
if (auto collapseShapeOp = mlir::dyn_cast<mlir::tensor::CollapseShapeOp>(user))
|
||||
return collapseShapeOp.getSrc() == currentValue && self(collapseShapeOp.getResult(), self);
|
||||
if (auto transposeOp = mlir::dyn_cast<mlir::ONNXTransposeOp>(user))
|
||||
return transposeOp.getData() == currentValue && self(transposeOp.getResult(), self);
|
||||
if (auto transposeOp = mlir::dyn_cast<mlir::linalg::TransposeOp>(user))
|
||||
return transposeOp.getInput() == currentValue && self(transposeOp.getResult()[0], self);
|
||||
|
||||
return false;
|
||||
});
|
||||
@@ -90,19 +124,149 @@ void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir
|
||||
assert(root && "expected valid root op");
|
||||
root->walk([&](pim::PimCoreOp coreOp) {
|
||||
coreOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weights = coreOp.getWeights();
|
||||
unsigned weightIndex = vmmOp.getWeightIndex();
|
||||
if (weightIndex < weights.size())
|
||||
callback(coreOp->getOpOperand(weightIndex));
|
||||
if (auto weightIndex = resolveWeightIndex(coreOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto weights = coreBatchOp.getWeights();
|
||||
for (auto weight : weights)
|
||||
for (mlir::OpOperand& use : weight.getUses())
|
||||
if (use.getOwner() == coreBatchOp.getOperation())
|
||||
callback(use);
|
||||
coreBatchOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
if (auto weightIndex = resolveWeightIndex(coreBatchOp.getOperation(), vmmOp.getWeight()))
|
||||
callback(coreBatchOp->getOpOperand(*weightIndex));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight) {
|
||||
weight = stripMemRefAddressingOps(weight);
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreOp.getWeights().size(); ++weightIndex)
|
||||
if (coreOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
for (unsigned weightIndex = 0; weightIndex < coreBatchOp.getWeights().size(); ++weightIndex)
|
||||
if (coreBatchOp.getWeightArgument(weightIndex) == weight)
|
||||
return weightIndex;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge) {
|
||||
llvm::SmallVector<mlir::Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
|
||||
while (true) {
|
||||
if (auto defOp = current.getDefiningOp()) {
|
||||
if (auto getGlobalOp = mlir::dyn_cast<mlir::memref::GetGlobalOp>(defOp)) {
|
||||
auto moduleOp = weightOwner ? weightOwner->getParentOfType<mlir::ModuleOp>() : mlir::ModuleOp {};
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return mlir::failure();
|
||||
|
||||
auto denseAttr = mlir::dyn_cast<mlir::DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return mlir::failure();
|
||||
|
||||
ResolvedWeightView view;
|
||||
view.globalOp = globalOp;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
CompiledIndexExpr offsetExpr = makeConstantExpr(0);
|
||||
for (mlir::Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(viewOp)) {
|
||||
llvm::SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getMixedOffsets().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getMixedOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
CompiledIndexExpr offsetValue = makeConstantExpr(0);
|
||||
if (auto attr = mlir::dyn_cast<mlir::Attribute>(offset)) {
|
||||
auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(attr);
|
||||
if (!intAttr)
|
||||
return mlir::failure();
|
||||
offsetValue = makeConstantExpr(intAttr.getInt());
|
||||
}
|
||||
else if (auto value = mlir::dyn_cast<mlir::Value>(offset)) {
|
||||
auto compiledOffset = compileIndexExpr(value);
|
||||
if (failed(compiledOffset))
|
||||
return mlir::failure();
|
||||
offsetValue = *compiledOffset;
|
||||
}
|
||||
else {
|
||||
return mlir::failure();
|
||||
}
|
||||
offsetExpr = addExpr(std::move(offsetExpr), mulExpr(std::move(offsetValue), sourceStride));
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return mlir::failure();
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = mlir::dyn_cast<mlir::memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return mlir::failure();
|
||||
auto resultType = mlir::cast<mlir::MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
}
|
||||
}
|
||||
|
||||
auto resolvedOffset = offsetExpr.evaluate(knowledge);
|
||||
if (failed(resolvedOffset))
|
||||
return mlir::failure();
|
||||
view.offset = *resolvedOffset;
|
||||
return view;
|
||||
}
|
||||
|
||||
if (mlir::isa<mlir::memref::SubViewOp, mlir::memref::CollapseShapeOp, mlir::memref::ExpandShapeOp>(defOp)) {
|
||||
viewOps.push_back(defOp);
|
||||
if (auto subview = mlir::dyn_cast<mlir::memref::SubViewOp>(defOp))
|
||||
current = subview.getSource();
|
||||
else if (auto collapse = mlir::dyn_cast<mlir::memref::CollapseShapeOp>(defOp))
|
||||
current = collapse.getSrc();
|
||||
else
|
||||
current = mlir::cast<mlir::memref::ExpandShapeOp>(defOp).getSrc();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto castOp = mlir::dyn_cast<mlir::memref::CastOp>(defOp)) {
|
||||
current = castOp.getSource();
|
||||
continue;
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto weightIndex = resolveWeightIndex(weightOwner, current);
|
||||
if (!weightIndex)
|
||||
return mlir::failure();
|
||||
|
||||
if (auto coreOp = mlir::dyn_cast_or_null<pim::PimCoreOp>(weightOwner)) {
|
||||
current = coreOp.getWeights()[*weightIndex];
|
||||
continue;
|
||||
}
|
||||
if (auto coreBatchOp = mlir::dyn_cast_or_null<pim::PimCoreBatchOp>(weightOwner)) {
|
||||
current = coreBatchOp.getWeights()[*weightIndex];
|
||||
continue;
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,15 +1,34 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
inline constexpr llvm::StringRef PimWeightAlwaysAttrName = "weightAlways";
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct ResolvedWeightView {
|
||||
mlir::memref::GlobalOp globalOp;
|
||||
llvm::SmallVector<int64_t> shape;
|
||||
llvm::SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
|
||||
bool operator==(const ResolvedWeightView& other) const {
|
||||
return globalOp == other.globalOp && shape == other.shape && strides == other.strides && offset == other.offset;
|
||||
}
|
||||
};
|
||||
|
||||
bool hasWeightAlways(mlir::Operation* op);
|
||||
|
||||
/// Tags an op as producing a value that should stay materialized as a reusable
|
||||
@@ -26,4 +45,20 @@ bool hasOnlySpatialMvmVmmWeightUses(mlir::Value value);
|
||||
/// passes can identify globals that must remain weight-backed.
|
||||
void walkPimMvmVmmWeightUses(mlir::Operation* root, llvm::function_ref<void(mlir::OpOperand&)> callback);
|
||||
|
||||
std::optional<unsigned> resolveWeightIndex(mlir::Operation* weightOwner, mlir::Value weight);
|
||||
llvm::FailureOr<ResolvedWeightView>
|
||||
resolveWeightView(mlir::Operation* weightOwner, mlir::Value weight, const StaticValueKnowledge& knowledge = {});
|
||||
|
||||
template <typename CoreLikeOpTy>
|
||||
llvm::SmallVector<unsigned, 8> getUsedWeightIndices(CoreLikeOpTy coreLikeOp) {
|
||||
llvm::SmallVector<unsigned, 8> indices;
|
||||
coreLikeOp.walk([&](pim::PimVMMOp vmmOp) {
|
||||
auto weightIndex = resolveWeightIndex(coreLikeOp.getOperation(), vmmOp.getWeight());
|
||||
if (weightIndex && !llvm::is_contained(indices, *weightIndex))
|
||||
indices.push_back(*weightIndex);
|
||||
});
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/CoreBlockUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/EntryPointUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
|
||||
@@ -18,7 +18,7 @@ void dumpModule(mlir::ModuleOp moduleOp, const std::string& name) {
|
||||
std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
|
||||
llvm::raw_os_ostream os(file);
|
||||
mlir::OpPrintingFlags flags;
|
||||
flags.elideLargeElementsAttrs();
|
||||
flags.elideLargeElementsAttrs().enableDebugInfo(true, false);
|
||||
moduleOp.print(os, flags);
|
||||
os.flush();
|
||||
file.close();
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
namespace onnx_mlir::pim {
|
||||
|
||||
struct CappedDiagnosticReporter {
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8) : maxReportedFailures(maxReportedFailures) {}
|
||||
explicit CappedDiagnosticReporter(int64_t maxReportedFailures = 8)
|
||||
: maxReportedFailures(maxReportedFailures) {}
|
||||
|
||||
template <typename EmitFn>
|
||||
void report(mlir::Operation* op, EmitFn&& emit) {
|
||||
@@ -24,8 +25,7 @@ struct CappedDiagnosticReporter {
|
||||
|
||||
void emitSuppressedSummary(mlir::Operation* op, llvm::StringRef failureDescription) const {
|
||||
if (numFailures > maxReportedFailures)
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional "
|
||||
<< failureDescription;
|
||||
op->emitError() << "suppressed " << (numFailures - maxReportedFailures) << " additional " << failureDescription;
|
||||
}
|
||||
|
||||
bool hasFailure() const { return numFailures != 0; }
|
||||
|
||||
@@ -16,7 +16,6 @@ add_pim_library(OMPimCompilerOptions
|
||||
add_pim_library(OMPimCompilerUtils
|
||||
PimCompilerUtils.cpp
|
||||
PimArtifactWriter.cpp
|
||||
PimBatchEmission.cpp
|
||||
PimCodeGen.cpp
|
||||
PimWeightEmitter.cpp
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ writeMemoryBinary(ModuleOp moduleOp, func::FuncOp funcOp, PimAcceleratorMemory&
|
||||
if (!denseAttr)
|
||||
return;
|
||||
|
||||
MemEntry memEntry = memory.hostMem.getMemEntry(getGlobalOp.getResult());
|
||||
MemEntry memEntry = memory.hostMem.getMemEntry({getGlobalOp.getResult(), std::nullopt});
|
||||
ArrayRef<char> rawData = denseAttr.getRawData();
|
||||
char* dst = memoryBuffer.data() + memEntry.address;
|
||||
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
|
||||
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
|
||||
assert(coreIdsAttr && "pim.core_batch requires coreIds array attribute");
|
||||
return SmallVector<int32_t>(coreIdsAttr.asArrayRef().begin(), coreIdsAttr.asArrayRef().end());
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_t laneCount, unsigned lane) {
|
||||
SmallVector<int32_t> laneCoreIds;
|
||||
laneCoreIds.reserve(coreIds.size() / laneCount);
|
||||
for (size_t chunkIndex = 0; chunkIndex < coreIds.size() / laneCount; ++chunkIndex)
|
||||
laneCoreIds.push_back(coreIds[chunkIndex * laneCount + lane]);
|
||||
return laneCoreIds;
|
||||
}
|
||||
|
||||
static void scalarizeBatchOpsInCore(pim::PimCoreOp scalarCore, size_t laneCount, unsigned lane) {
|
||||
IRRewriter rewriter(scalarCore.getContext());
|
||||
SmallVector<Operation*> batchOps;
|
||||
scalarCore.walk([&](Operation* op) {
|
||||
if (isa<pim::PimSendBatchOp,
|
||||
pim::PimSendTensorBatchOp,
|
||||
pim::PimReceiveBatchOp,
|
||||
pim::PimReceiveTensorBatchOp,
|
||||
pim::PimMemCopyHostToDevBatchOp>(op)) {
|
||||
batchOps.push_back(op);
|
||||
}
|
||||
});
|
||||
|
||||
for (Operation* op : batchOps) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
sendBatchOp.getLoc(),
|
||||
sendBatchOp.getInput(),
|
||||
sendBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(sendBatchOp.getTargetCoreIds()[lane]));
|
||||
rewriter.eraseOp(op);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<pim::PimSendTensorBatchOp>(op)) {
|
||||
pim::PimSendTensorOp::create(
|
||||
rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
sendTensorBatchOp.getInput(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(sendTensorBatchOp.getTargetCoreIds(), laneCount, lane)));
|
||||
rewriter.eraseOp(op);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
|
||||
auto scalarReceive =
|
||||
pim::PimReceiveOp::create(rewriter,
|
||||
receiveBatchOp.getLoc(),
|
||||
receiveBatchOp.getOutput().getType(),
|
||||
receiveBatchOp.getOutputBuffer(),
|
||||
receiveBatchOp.getSizeAttr(),
|
||||
rewriter.getI32IntegerAttr(receiveBatchOp.getSourceCoreIds()[lane]));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<pim::PimReceiveTensorBatchOp>(op)) {
|
||||
auto scalarReceive = pim::PimReceiveTensorOp::create(
|
||||
rewriter,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
receiveTensorBatchOp.getOutput().getType(),
|
||||
receiveTensorBatchOp.getOutputBuffer(),
|
||||
rewriter.getDenseI32ArrayAttr(getLaneChunkCoreIds(receiveTensorBatchOp.getSourceCoreIds(), laneCount, lane)));
|
||||
rewriter.replaceOp(op, scalarReceive->getResults());
|
||||
continue;
|
||||
}
|
||||
|
||||
auto memcpBatchOp = cast<pim::PimMemCopyHostToDevBatchOp>(op);
|
||||
auto scalarCopy = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
memcpBatchOp.getLoc(),
|
||||
memcpBatchOp.getOutput().getType(),
|
||||
memcpBatchOp.getDeviceTarget(),
|
||||
memcpBatchOp.getHostSource(),
|
||||
memcpBatchOp.getDeviceTargetOffsetAttr(),
|
||||
memcpBatchOp.getHostSourceOffsetAttr(),
|
||||
memcpBatchOp.getSizeAttr());
|
||||
rewriter.replaceOp(op, scalarCopy->getResults());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<LogicalResult(pim::PimCoreOp)> callback) {
|
||||
OwningOpRef<ModuleOp> scratchModule = ModuleOp::create(coreBatchOp.getLoc());
|
||||
OpBuilder builder(scratchModule->getContext());
|
||||
builder.setInsertionPointToStart(scratchModule->getBody());
|
||||
|
||||
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
|
||||
size_t weightsPerLane = coreBatchOp.getWeights().size() / laneCount;
|
||||
SmallVector<Value> laneWeights;
|
||||
laneWeights.reserve(weightsPerLane);
|
||||
for (size_t weightIndex = 0; weightIndex < weightsPerLane; ++weightIndex)
|
||||
laneWeights.push_back(coreBatchOp.getWeights()[lane * weightsPerLane + weightIndex]);
|
||||
|
||||
auto coreIds = getBatchCoreIds(coreBatchOp);
|
||||
auto scalarCore = pim::PimCoreOp::create(
|
||||
builder, coreBatchOp.getLoc(), ValueRange(laneWeights), builder.getI32IntegerAttr(coreIds[lane]));
|
||||
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
|
||||
IRMapping mapper;
|
||||
if (coreBatchOp.getBody().front().getNumArguments() == 1)
|
||||
mapper.map(coreBatchOp.getBody().front().getArgument(0), coreBatchOp.getInputs()[lane]);
|
||||
|
||||
builder.setInsertionPointToEnd(block);
|
||||
for (Operation& op : coreBatchOp.getBody().front()) {
|
||||
Operation* cloned = builder.clone(op, mapper);
|
||||
for (auto [originalResult, clonedResult] : llvm::zip(op.getResults(), cloned->getResults()))
|
||||
mapper.map(originalResult, clonedResult);
|
||||
}
|
||||
|
||||
if (block->empty() || !isa<pim::PimHaltOp>(block->back()))
|
||||
pim::PimHaltOp::create(builder, coreBatchOp.getLoc());
|
||||
scalarizeBatchOpsInCore(scalarCore, laneCount, lane);
|
||||
return callback(scalarCore);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,13 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "llvm/ADT/STLFunctionalExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult withScalarCoreFromBatchLane(pim::PimCoreBatchOp coreBatchOp,
|
||||
unsigned lane,
|
||||
llvm::function_ref<mlir::LogicalResult(pim::PimCoreOp)> callback);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+583
-225
File diff suppressed because it is too large
Load Diff
@@ -4,13 +4,16 @@
|
||||
|
||||
#include "llvm-project/clang/include/clang/Basic/LLVM.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_os_ostream.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <limits>
|
||||
#include <optional>
|
||||
|
||||
#include "onnx-mlir/Compiler/OMCompilerTypes.h"
|
||||
#include "src/Accelerators/PIM/Common/IR/AddressAnalysis.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBinaryFormat.hpp"
|
||||
@@ -23,6 +26,13 @@ struct MemEntry {
|
||||
size_t size;
|
||||
};
|
||||
|
||||
struct MemoryValueKey {
|
||||
mlir::Value value;
|
||||
std::optional<unsigned> lane;
|
||||
|
||||
bool operator==(const MemoryValueKey& other) const { return value == other.value && lane == other.lane; }
|
||||
};
|
||||
|
||||
struct MemoryReportRow {
|
||||
uint64_t numAlloca = 0;
|
||||
uint64_t sizeAlloca = 0;
|
||||
@@ -50,33 +60,33 @@ struct MemoryReportEntry {
|
||||
};
|
||||
|
||||
class PimMemory {
|
||||
llvm::SmallVector<std::pair<MemEntry, mlir::Value>, 32> memEntries;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap;
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> ownedMemEntriesMap;
|
||||
llvm::SmallVector<std::pair<MemEntry, MemoryValueKey>, 32> memEntries;
|
||||
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap;
|
||||
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> ownedMemEntriesMap;
|
||||
|
||||
size_t minAlignment = 4;
|
||||
size_t firstAvailableAddress = 0;
|
||||
|
||||
MemEntry* gatherMemEntry(mlir::Value value);
|
||||
MemEntry* gatherMemEntry(mlir::Value value, std::optional<unsigned> lane = std::nullopt);
|
||||
void allocateGatheredMemory();
|
||||
void allocateMemoryForValue(mlir::Value value, MemEntry& memEntry);
|
||||
void allocateMemoryForValue(const MemoryValueKey& key, MemEntry& memEntry);
|
||||
|
||||
public:
|
||||
PimMemory(llvm::SmallDenseMap<mlir::Value, MemEntry, 32>& globalMemEntriesMap)
|
||||
PimMemory(llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& globalMemEntriesMap)
|
||||
: globalMemEntriesMap(globalMemEntriesMap) {}
|
||||
|
||||
void allocateHost(mlir::ModuleOp moduleOp, mlir::func::FuncOp funcOp);
|
||||
void allocateCore(mlir::Operation* op);
|
||||
void allocateCore(mlir::Operation* op, std::optional<unsigned> lane = std::nullopt);
|
||||
MemoryReportRow getReportRow() const;
|
||||
void remove(mlir::Value val);
|
||||
|
||||
size_t getFirstAvailableAddress() const { return firstAvailableAddress; }
|
||||
MemEntry getMemEntry(mlir::Value value) const;
|
||||
MemEntry getMemEntry(const MemoryValueKey& key) const;
|
||||
};
|
||||
|
||||
class PimAcceleratorMemory {
|
||||
public:
|
||||
llvm::SmallDenseMap<mlir::Value, MemEntry, 32> memEntriesMap;
|
||||
llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32> memEntriesMap;
|
||||
PimMemory hostMem;
|
||||
|
||||
private:
|
||||
@@ -84,14 +94,23 @@ private:
|
||||
std::fstream fileReport;
|
||||
std::optional<MemoryReportRow> hostReportRow;
|
||||
llvm::SmallVector<MemoryReportEntry, 32> reportEntries;
|
||||
mutable llvm::DenseMap<mlir::Value, CompiledIndexExpr> compiledIndexExprs;
|
||||
mutable llvm::DenseMap<mlir::Value, CompiledAddressExpr> compiledAddressExprs;
|
||||
|
||||
public:
|
||||
PimAcceleratorMemory()
|
||||
: hostMem(memEntriesMap), fileReport(openReportFile("memory_report")) {}
|
||||
PimAcceleratorMemory(const llvm::SmallDenseMap<MemoryValueKey, MemEntry, 32>& initialMemEntries, bool enableReport)
|
||||
: memEntriesMap(initialMemEntries),
|
||||
hostMem(memEntriesMap),
|
||||
fileReport(enableReport ? openReportFile("memory_report") : std::fstream()) {}
|
||||
|
||||
PimMemory& getOrCreateDeviceMem(size_t id);
|
||||
|
||||
size_t getValueAddress(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
size_t getValueAddress(mlir::Value value,
|
||||
const StaticValueKnowledge& knowledge = {},
|
||||
std::optional<unsigned> lane = std::nullopt) const;
|
||||
llvm::FailureOr<int64_t> getIndexValue(mlir::Value value, const StaticValueKnowledge& knowledge = {}) const;
|
||||
void reportHost();
|
||||
void recordCoreReport(size_t coreId, const MemoryReportRow& row);
|
||||
void recordBatchReport(uint64_t batchId,
|
||||
@@ -103,15 +122,24 @@ public:
|
||||
void clean(mlir::Operation* op);
|
||||
};
|
||||
|
||||
struct CoreEmissionJob {
|
||||
mlir::Operation* coreLikeOp = nullptr;
|
||||
size_t originalCoreId = 0;
|
||||
size_t emittedCoreId = 0;
|
||||
llvm::SmallVector<unsigned, 4> lanes;
|
||||
std::optional<uint64_t> batchReportId;
|
||||
};
|
||||
|
||||
class PimCodeGen {
|
||||
PimAcceleratorMemory& memory;
|
||||
llvm::raw_fd_ostream& coreBinaryStream;
|
||||
llvm::raw_fd_ostream* coreJsonStream;
|
||||
const llvm::DenseMap<size_t, size_t>& emittedCoreIds;
|
||||
std::optional<unsigned> batchLane;
|
||||
mutable uint32_t emittedInstructionCount = 0;
|
||||
|
||||
size_t addressOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getValueAddress(value, knowledge);
|
||||
return memory.getValueAddress(value, knowledge, batchLane);
|
||||
}
|
||||
size_t remapCoreId(size_t coreId) const;
|
||||
|
||||
@@ -141,15 +169,17 @@ public:
|
||||
: memory(memory), coreBinaryStream(coreBinary), coreJsonStream(coreJson), emittedCoreIds(emittedCoreIds) {}
|
||||
|
||||
uint32_t getEmittedInstructionCount() const { return emittedInstructionCount; }
|
||||
void setBatchLane(std::optional<unsigned> lane) { batchLane = lane; }
|
||||
llvm::FailureOr<int64_t> indexOf(mlir::Value value, const StaticValueKnowledge& knowledge) const {
|
||||
return memory.getIndexValue(value, knowledge);
|
||||
}
|
||||
|
||||
void codeGenLoadOp(pim::PimMemCopyHostToDevOp loadOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenStoreOp(pim::PimMemCopyDevToHostOp storeOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenLmvOp(pim::PimMemCopyOp lmvOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
void codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const;
|
||||
void codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKnowledge& knowledge) const;
|
||||
|
||||
template <typename MVMTy>
|
||||
@@ -172,3 +202,20 @@ public:
|
||||
OnnxMlirCompilerErrorCodes compileToPimCode(mlir::ModuleOp& moduleOpRef, std::string& outputDirName);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
namespace llvm {
|
||||
|
||||
template <>
|
||||
struct DenseMapInfo<onnx_mlir::MemoryValueKey> {
|
||||
static onnx_mlir::MemoryValueKey getEmptyKey() { return {DenseMapInfo<mlir::Value>::getEmptyKey(), 0}; }
|
||||
|
||||
static onnx_mlir::MemoryValueKey getTombstoneKey() { return {DenseMapInfo<mlir::Value>::getTombstoneKey(), 0}; }
|
||||
|
||||
static unsigned getHashValue(const onnx_mlir::MemoryValueKey& key) {
|
||||
return hash_combine(key.value, key.lane.value_or(std::numeric_limits<unsigned>::max()));
|
||||
}
|
||||
|
||||
static bool isEqual(const onnx_mlir::MemoryValueKey& lhs, const onnx_mlir::MemoryValueKey& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
|
||||
#define DEBUG_TYPE "PimCompilerOptions"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -15,11 +15,10 @@ llvm::cl::opt<PimEmissionTargetType> pimEmissionTarget(
|
||||
llvm::cl::init(EmitPimCodegen),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
llvm::cl::opt<PimMergeSchedulerType> pimMergeScheduler(
|
||||
"pim-merge-scheduler",
|
||||
llvm::cl::opt<PimMergeSchedulerType>
|
||||
pimMergeScheduler("pim-merge-scheduler",
|
||||
llvm::cl::desc("Scheduler used by the Spatial merge-compute-nodes pass"),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerPeft, "peft", "Use PEFT scheduling")),
|
||||
llvm::cl::values(clEnumValN(MergeSchedulerDcp, "dcp", "Use the legacy DCP-inspired scheduler")),
|
||||
llvm::cl::init(MergeSchedulerPeft),
|
||||
llvm::cl::cat(OnnxMlirOptions));
|
||||
|
||||
@@ -49,12 +48,6 @@ llvm::cl::opt<long> coresCount("core-count",
|
||||
llvm::cl::desc("Number of cores in the chip. Required for PIM compilation."),
|
||||
llvm::cl::init(-1));
|
||||
|
||||
llvm::cl::opt<size_t> dcpCriticalWindowSize(
|
||||
"dcp-critical-window-size",
|
||||
llvm::cl::desc("Number of lowest-slack virtual nodes considered by each DCP coarsening iteration. "
|
||||
"Use 0 to run the legacy full-graph DCP analysis. Only used by the DCP scheduler."),
|
||||
llvm::cl::init(4000));
|
||||
|
||||
llvm::cl::opt<bool>
|
||||
ignoreConcatError("ignore-concat-error",
|
||||
llvm::cl::desc("Ignore ConcatOp corner case: do not assert and do a simplification"),
|
||||
|
||||
@@ -22,7 +22,6 @@ typedef enum {
|
||||
|
||||
typedef enum {
|
||||
MergeSchedulerPeft = 0,
|
||||
MergeSchedulerDcp = 1,
|
||||
} PimMergeSchedulerType;
|
||||
|
||||
extern llvm::cl::OptionCategory OnnxMlirOptions;
|
||||
@@ -36,7 +35,6 @@ extern llvm::cl::opt<bool> pimEmitJson;
|
||||
extern llvm::cl::opt<size_t> crossbarSize;
|
||||
extern llvm::cl::opt<size_t> crossbarCountInCore;
|
||||
extern llvm::cl::opt<long> coresCount;
|
||||
extern llvm::cl::opt<size_t> dcpCriticalWindowSize;
|
||||
|
||||
bool hasExplicitPimCoreCount();
|
||||
void verifyExplicitPimCoreCount();
|
||||
|
||||
@@ -30,20 +30,17 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
if (pimEmissionTarget >= EmitSpatial) {
|
||||
pm.addPass(createONNXToSpatialPass());
|
||||
pm.addPass(createMergeComputeNodesPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Onnx lowered to Spatial"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPim) {
|
||||
pm.addPass(createSpatialToPimPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Spatial lowered to Pim"));
|
||||
}
|
||||
|
||||
if (pimEmissionTarget >= EmitPimBufferized) {
|
||||
pm.addPass(createPimBufferizationPass());
|
||||
pm.addPass(createPimStaticMemoryCoalescingPass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim bufferized"));
|
||||
}
|
||||
|
||||
@@ -54,7 +51,6 @@ void addPassesPim(OwningOpRef<ModuleOp>& module,
|
||||
pm.addPass(createPimVerificationPass());
|
||||
pm.addPass(createMessagePass("Pim verified"));
|
||||
pm.addPass(createEmitPimCodePass());
|
||||
// pm.addPass(createCountInstructionPass());
|
||||
pm.addPass(createMessagePass("Pim code emitted"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
@@ -11,196 +9,42 @@
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimBatchEmission.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCodeGen.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimWeightEmitter.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
namespace {} // namespace
|
||||
|
||||
struct DenseWeightView {
|
||||
DenseElementsAttr denseAttr;
|
||||
SmallVector<int64_t> shape;
|
||||
SmallVector<int64_t> strides;
|
||||
int64_t offset = 0;
|
||||
};
|
||||
|
||||
FailureOr<DenseWeightView> resolveDenseWeightView(ModuleOp moduleOp, mlir::Value weight) {
|
||||
SmallVector<Operation*> viewOps;
|
||||
mlir::Value current = weight;
|
||||
memref::GetGlobalOp getGlobalOp;
|
||||
|
||||
while (true) {
|
||||
Operation* defOp = current.getDefiningOp();
|
||||
if (!defOp)
|
||||
return failure();
|
||||
if ((getGlobalOp = dyn_cast<memref::GetGlobalOp>(defOp)))
|
||||
break;
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(defOp)) {
|
||||
if (!hasAllStaticSubviewParts(subview))
|
||||
return failure();
|
||||
viewOps.push_back(subview);
|
||||
current = subview.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto cast = dyn_cast<memref::CastOp>(defOp)) {
|
||||
current = cast.getSource();
|
||||
continue;
|
||||
}
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(collapse.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(collapse.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(collapse);
|
||||
current = collapse.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(defOp)) {
|
||||
auto srcType = dyn_cast<MemRefType>(expand.getSrc().getType());
|
||||
auto resultType = dyn_cast<MemRefType>(expand.getResult().getType());
|
||||
if (!srcType || !resultType || !srcType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
viewOps.push_back(expand);
|
||||
current = expand.getSrc();
|
||||
continue;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto globalOp = lookupGlobalForGetGlobal(moduleOp, getGlobalOp);
|
||||
if (!globalOp || !globalOp.getInitialValue())
|
||||
return failure();
|
||||
|
||||
auto denseAttr = dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
DenseWeightView view;
|
||||
view.denseAttr = denseAttr;
|
||||
view.shape.assign(denseAttr.getType().getShape().begin(), denseAttr.getType().getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
|
||||
for (Operation* viewOp : llvm::reverse(viewOps)) {
|
||||
if (auto subview = dyn_cast<memref::SubViewOp>(viewOp)) {
|
||||
SmallVector<int64_t> nextStrides;
|
||||
nextStrides.reserve(subview.getStaticStrides().size());
|
||||
for (auto [offset, stride, sourceStride] :
|
||||
llvm::zip_equal(subview.getStaticOffsets(), subview.getStaticStrides(), view.strides)) {
|
||||
view.offset += offset * sourceStride;
|
||||
nextStrides.push_back(stride * sourceStride);
|
||||
}
|
||||
view.shape.assign(subview.getStaticSizes().begin(), subview.getStaticSizes().end());
|
||||
view.strides = std::move(nextStrides);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Collapse/expand are accepted only as contiguous static reshapes of a
|
||||
// dense global view, so a row-major stride recomputation preserves layout.
|
||||
if (auto collapse = dyn_cast<memref::CollapseShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(collapse.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto expand = dyn_cast<memref::ExpandShapeOp>(viewOp)) {
|
||||
if (view.strides != computeRowMajorStrides(view.shape))
|
||||
return failure();
|
||||
auto resultType = cast<MemRefType>(expand.getResult().getType());
|
||||
view.shape.assign(resultType.getShape().begin(), resultType.getShape().end());
|
||||
view.strides = computeRowMajorStrides(view.shape);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return view;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(Block& block) {
|
||||
SmallVector<unsigned, 8> indices;
|
||||
auto addIndex = [&](unsigned weightIndex) {
|
||||
if (!llvm::is_contained(indices, weightIndex))
|
||||
indices.push_back(weightIndex);
|
||||
};
|
||||
|
||||
block.walk([&](pim::PimVMMOp vmmOp) { addIndex(vmmOp.getWeightIndex()); });
|
||||
llvm::sort(indices);
|
||||
return indices;
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 8> getUsedWeightIndices(pim::PimCoreOp coreOp) {
|
||||
return getUsedWeightIndices(coreOp.getBody().front());
|
||||
}
|
||||
|
||||
SmallVector<Operation*> collectTopLevelCoreLikeOps(func::FuncOp funcOp) {
|
||||
SmallVector<Operation*> coreLikeOps;
|
||||
for (Operation& op : funcOp.getBody().front())
|
||||
if (dyn_cast<pim::PimCoreOp>(&op) || dyn_cast<pim::PimCoreBatchOp>(&op))
|
||||
coreLikeOps.push_back(&op);
|
||||
return coreLikeOps;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
ModuleOp moduleOp = funcOp->getParentOfType<ModuleOp>();
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
||||
createAndPopulateWeightFolder(ArrayRef<WeightFileRequest> requests, StringRef outputDirPath) {
|
||||
auto coreWeightsDirPath = outputDirPath + "/weights";
|
||||
auto error = sys::fs::create_directory(coreWeightsDirPath);
|
||||
assert(!error && "Error creating weights directory");
|
||||
size_t indexFileName = 0;
|
||||
|
||||
int64_t xbarSize = crossbarSize.getValue();
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>> mapCoreWeightToFileName;
|
||||
llvm::DenseMap<memref::GlobalOp, std::string> mapGlobalOpToFileName;
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>> mapCoreWeightToFileName;
|
||||
llvm::SmallVector<std::pair<ResolvedWeightView, std::string>, 16> materializedWeights;
|
||||
|
||||
SmallVector<Operation*> coreLikeOps = collectTopLevelCoreLikeOps(funcOp);
|
||||
auto materializeWeight = [&](const ResolvedWeightView& weightView) -> std::string {
|
||||
if (auto it = llvm::find_if(materializedWeights, [&](const auto& entry) { return entry.first == weightView; });
|
||||
it != materializedWeights.end())
|
||||
return it->second;
|
||||
|
||||
for (Operation* op : coreLikeOps) {
|
||||
auto processCore = [&](pim::PimCoreOp coreOp) {
|
||||
size_t coreId = static_cast<size_t>(coreOp.getCoreId());
|
||||
for (unsigned index : getUsedWeightIndices(coreOp)) {
|
||||
if (index >= coreOp.getWeights().size()) {
|
||||
coreOp.emitWarning("Weight index " + std::to_string(index) + " is out of range");
|
||||
assert(index < coreOp.getWeights().size() && "Weight index is out of range");
|
||||
}
|
||||
mlir::Value weight = coreOp.getWeights()[index];
|
||||
auto globalOp = weightView.globalOp;
|
||||
auto denseAttr = mlir::dyn_cast<DenseElementsAttr>(*globalOp.getInitialValue());
|
||||
assert(denseAttr && "Weight global must have dense initial value");
|
||||
|
||||
auto weightView = resolveDenseWeightView(moduleOp, weight);
|
||||
if (failed(weightView)) {
|
||||
coreOp.emitWarning("Weight is not from a memref.get_global at index " + std::to_string(index));
|
||||
assert(succeeded(weightView) && "Weight is not from a dense memref.global view");
|
||||
}
|
||||
|
||||
if (mapCoreWeightToFileName[coreId].contains(weight))
|
||||
continue;
|
||||
|
||||
auto getGlobalOp = weight.getDefiningOp<memref::GetGlobalOp>();
|
||||
auto globalOp = getGlobalOp ? lookupGlobalForGetGlobal(moduleOp, getGlobalOp) : memref::GlobalOp {};
|
||||
if (globalOp && mapGlobalOpToFileName.contains(globalOp)) {
|
||||
auto& fileName = mapGlobalOpToFileName[globalOp];
|
||||
mapCoreWeightToFileName[coreId].insert({weight, fileName});
|
||||
continue;
|
||||
}
|
||||
|
||||
DenseElementsAttr denseAttr = weightView->denseAttr;
|
||||
ArrayRef<int64_t> shape = weightView->shape;
|
||||
ArrayRef<int64_t> shape = weightView.shape;
|
||||
assert(isMatrixShape(shape) && "Weight matrix must be 2-dimensional");
|
||||
int64_t numRows = shape[0];
|
||||
int64_t numCols = shape[1];
|
||||
assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size");
|
||||
|
||||
size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType());
|
||||
|
||||
std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin";
|
||||
auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str();
|
||||
@@ -215,7 +59,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
for (int64_t row = 0; row < xbarSize; row++) {
|
||||
for (int64_t col = 0; col < xbarSize; col++) {
|
||||
if (row < numRows && col < numCols) {
|
||||
int64_t elementIndex = weightView->offset + row * weightView->strides[0] + col * weightView->strides[1];
|
||||
int64_t elementIndex = weightView.offset + row * weightView.strides[0] + col * weightView.strides[1];
|
||||
APInt bits = denseAttr.getValues<APFloat>()[elementIndex].bitcastToAPInt();
|
||||
uint64_t word = bits.getZExtValue();
|
||||
weightFileStream.write(reinterpret_cast<const char*>(&word), elementByteWidth);
|
||||
@@ -227,23 +71,17 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) {
|
||||
}
|
||||
|
||||
weightFileStream.close();
|
||||
if (globalOp)
|
||||
mapGlobalOpToFileName.insert({globalOp, newFileName});
|
||||
mapCoreWeightToFileName[coreId].insert({weight, newFileName});
|
||||
}
|
||||
return success();
|
||||
materializedWeights.push_back({weightView, newFileName});
|
||||
return newFileName;
|
||||
};
|
||||
|
||||
if (auto coreOp = dyn_cast<pim::PimCoreOp>(op)) {
|
||||
(void) processCore(coreOp);
|
||||
continue;
|
||||
for (const WeightFileRequest& request : requests) {
|
||||
auto& coreFiles = mapCoreWeightToFileName[request.coreId];
|
||||
coreFiles.reserve(request.weights.size());
|
||||
for (const ResolvedWeightView& weight : request.weights)
|
||||
coreFiles.push_back(materializeWeight(weight));
|
||||
}
|
||||
|
||||
auto coreBatchOp = cast<pim::PimCoreBatchOp>(op);
|
||||
for (unsigned lane = 0; lane < static_cast<unsigned>(coreBatchOp.getLaneCount()); ++lane)
|
||||
if (failed(withScalarCoreFromBatchLane(coreBatchOp, lane, processCore)))
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
return mapCoreWeightToFileName;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,23 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
llvm::DenseMap<size_t, llvm::DenseMap<mlir::Value, std::string>>
|
||||
createAndPopulateWeightFolder(mlir::func::FuncOp funcOp, llvm::StringRef outputDirPath);
|
||||
struct WeightFileRequest {
|
||||
size_t coreId = 0;
|
||||
llvm::SmallVector<ResolvedWeightView, 8> weights;
|
||||
};
|
||||
|
||||
llvm::DenseMap<size_t, llvm::SmallVector<std::string, 8>>
|
||||
createAndPopulateWeightFolder(llvm::ArrayRef<WeightFileRequest> requests, llvm::StringRef outputDirPath);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,11 +3,12 @@ mlir_tablegen(ONNXToSpatial.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(ONNXToSpatialIncGen)
|
||||
|
||||
add_pim_library(OMONNXToSpatial
|
||||
ConversionPatterns.cpp
|
||||
HostFoldability.cpp
|
||||
HostLegality.cpp
|
||||
PrePatterns.cpp
|
||||
PostPatterns.cpp
|
||||
Patterns.cpp
|
||||
CompileTime.cpp
|
||||
ONNXToSpatialVerifier.cpp
|
||||
Patterns/Pre.cpp
|
||||
Patterns/Post.cpp
|
||||
Patterns/GeneratedConversion.cpp
|
||||
Patterns/Math/Conv.cpp
|
||||
Patterns/Math/Elementwise.cpp
|
||||
Patterns/Math/Gemm.cpp
|
||||
@@ -22,8 +23,11 @@ add_pim_library(OMONNXToSpatial
|
||||
Patterns/Tensor/Resize.cpp
|
||||
Patterns/Tensor/Reshape.cpp
|
||||
Patterns/Tensor/Split.cpp
|
||||
Patterns/Tensor/Transpose.cpp
|
||||
ONNXToSpatialPass.cpp
|
||||
Common/AttributeUtils.cpp
|
||||
Common/ComputeRegionBuilder.cpp
|
||||
Common/IndexingUtils.cpp
|
||||
Common/ShapeTilingUtils.cpp
|
||||
Common/WeightMaterialization.cpp
|
||||
|
||||
@@ -33,6 +37,7 @@ add_pim_library(OMONNXToSpatial
|
||||
ONNXToSpatialIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t getI64Attr(ArrayAttr attr, size_t index) { return cast<IntegerAttr>(attr[index]).getInt(); }
|
||||
|
||||
int64_t getOptionalI64Attr(std::optional<ArrayAttr> attr, size_t index, int64_t defaultValue) {
|
||||
return attr ? getI64Attr(*attr, index) : defaultValue;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t> getI64ArrayAttrValues(ArrayAttr attr) {
|
||||
llvm::SmallVector<int64_t> values;
|
||||
values.reserve(attr.size());
|
||||
for (Attribute value : attr)
|
||||
values.push_back(cast<IntegerAttr>(value).getInt());
|
||||
return values;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t getI64Attr(mlir::ArrayAttr attr, size_t index);
|
||||
|
||||
int64_t getOptionalI64Attr(std::optional<mlir::ArrayAttr> attr, size_t index, int64_t defaultValue);
|
||||
|
||||
llvm::SmallVector<int64_t> getI64ArrayAttrValues(mlir::ArrayAttr attr);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,6 +1,8 @@
|
||||
#pragma once
|
||||
|
||||
#include "AttributeUtils.hpp"
|
||||
#include "ComputeRegionBuilder.hpp"
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/Block.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
@@ -7,9 +8,11 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <limits>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
@@ -18,13 +21,17 @@ namespace detail {
|
||||
|
||||
inline mlir::ValueRange getBlockArgs(mlir::Block* block) { return mlir::ValueRange(block->getArguments()); }
|
||||
|
||||
inline mlir::ValueRange getInputBlockArgs(mlir::Block* block, size_t weightCount) {
|
||||
return mlir::ValueRange(block->getArguments()).drop_front(weightCount);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithBlockArgs(Fn&& fn, mlir::Block* block, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(block->getArgument(Is)...);
|
||||
}
|
||||
|
||||
template <typename Fn, size_t... Is>
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ArrayRef<mlir::Value> values, std::index_sequence<Is...>) {
|
||||
decltype(auto) invokeWithValues(Fn&& fn, mlir::ValueRange values, std::index_sequence<Is...>) {
|
||||
return std::forward<Fn>(fn)(values[Is]...);
|
||||
}
|
||||
|
||||
@@ -45,6 +52,13 @@ using InvokeWithBlockArgsResultT = typename InvokeWithBlockArgsResult<Fn, Seq>::
|
||||
template <typename Fn>
|
||||
using InvokeWithValueRangeResultT = std::invoke_result_t<Fn, mlir::ValueRange>;
|
||||
|
||||
struct SpatComputeBatchBodyArgs {
|
||||
mlir::Value lane;
|
||||
mlir::ValueRange weights;
|
||||
mlir::ValueRange inputs;
|
||||
mlir::ValueRange outputs;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename RewriterT>
|
||||
@@ -85,6 +99,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -93,14 +109,17 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithBlockArgsResultT<std::decay_t<BodyFn>, std::make_index_sequence<NumInputs>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult =
|
||||
detail::invokeWithBlockArgs(std::forward<BodyFn>(body), block, std::make_index_sequence<NumInputs> {});
|
||||
auto bodyResult = detail::invokeWithValues(std::forward<BodyFn>(body),
|
||||
detail::getInputBlockArgs(block, weights.size()),
|
||||
std::make_index_sequence<NumInputs> {});
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
@@ -123,6 +142,8 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
auto computeOp = spatial::SpatCompute::create(rewriter, loc, resultTypes, weights, inputs);
|
||||
|
||||
auto* block = new mlir::Block();
|
||||
for (mlir::Value weight : weights)
|
||||
block->addArgument(weight.getType(), loc);
|
||||
for (mlir::Value input : inputs)
|
||||
block->addArgument(input.getType(), loc);
|
||||
|
||||
@@ -131,13 +152,13 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
|
||||
using BodyResult = detail::InvokeWithValueRangeResultT<std::decay_t<BodyFn>>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
return computeOp;
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getBlockArgs(block));
|
||||
auto bodyResult = std::forward<BodyFn>(body)(detail::getInputBlockArgs(block, weights.size()));
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
rewriter.eraseOp(computeOp);
|
||||
@@ -148,6 +169,95 @@ auto createSpatCompute(RewriterT& rewriter,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RewriterT, typename BodyFn>
|
||||
auto createSpatComputeBatch(RewriterT& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::TypeRange resultTypes,
|
||||
int64_t laneCount,
|
||||
mlir::ValueRange weights,
|
||||
mlir::ValueRange inputs,
|
||||
BodyFn&& body) {
|
||||
if (laneCount <= 0 || laneCount > std::numeric_limits<int32_t>::max())
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
|
||||
auto batchOp = spatial::SpatComputeBatch::create(
|
||||
rewriter, loc, resultTypes, rewriter.getI32IntegerAttr(static_cast<int32_t>(laneCount)), weights, inputs);
|
||||
|
||||
mlir::SmallVector<mlir::Type> blockArgTypes {rewriter.getIndexType()};
|
||||
mlir::SmallVector<mlir::Location> blockArgLocs {loc};
|
||||
blockArgTypes.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||
blockArgLocs.reserve(1 + weights.size() + inputs.size() + resultTypes.size());
|
||||
for (mlir::Value weight : weights) {
|
||||
blockArgTypes.push_back(weight.getType());
|
||||
blockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
for (mlir::Value input : inputs) {
|
||||
blockArgTypes.push_back(input.getType());
|
||||
blockArgLocs.push_back(input.getLoc());
|
||||
}
|
||||
for (mlir::Type resultType : resultTypes) {
|
||||
blockArgTypes.push_back(resultType);
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* block =
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), mlir::TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
|
||||
detail::SpatComputeBatchBodyArgs args {
|
||||
block->getArgument(0),
|
||||
mlir::ValueRange(block->getArguments()).slice(1, weights.size()),
|
||||
mlir::ValueRange(block->getArguments()).slice(1 + weights.size(), inputs.size()),
|
||||
mlir::ValueRange(block->getArguments()).drop_front(1 + weights.size() + inputs.size())};
|
||||
|
||||
using BodyResult = std::invoke_result_t<BodyFn, detail::SpatComputeBatchBodyArgs>;
|
||||
if constexpr (std::is_same_v<BodyResult, void>) {
|
||||
std::forward<BodyFn>(body)(args);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
}
|
||||
else {
|
||||
auto bodyResult = std::forward<BodyFn>(body)(args);
|
||||
if (mlir::failed(bodyResult)) {
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
rewriter.eraseOp(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(mlir::failure());
|
||||
}
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
return mlir::FailureOr<spatial::SpatComputeBatch>(batchOp);
|
||||
}
|
||||
}
|
||||
|
||||
inline void createParallelInsertSliceIntoBatchOutput(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> offsets,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> sizes,
|
||||
mlir::ArrayRef<mlir::OpFoldResult> strides) {
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
mlir::tensor::ParallelInsertSliceOp::create(rewriter, loc, source, dest, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
template <typename BodyFn>
|
||||
mlir::Value materializeOrComputeUnary(mlir::Value input,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
BodyFn&& build) {
|
||||
auto&& buildFn = build;
|
||||
if (isCompileTimeComputable(input))
|
||||
return buildFn(input);
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, mlir::TypeRange {resultType}, {}, mlir::ValueRange {input}, [&](mlir::Value computeInput) {
|
||||
mlir::Value result = buildFn(computeInput);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
mlir::Value sumTensors(mlir::ArrayRef<mlir::Value> tensors, mlir::ConversionPatternRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
#include <algorithm>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank) {
|
||||
int64_t normalizedAxis = normalizeAxis(axis, rank);
|
||||
if (normalizedAxis < 0 || normalizedAxis >= rank)
|
||||
return failure();
|
||||
return normalizedAxis;
|
||||
}
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
||||
|
||||
static SmallVector<int64_t> normalizeAxesImpl(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes;
|
||||
if (!axesAttr) {
|
||||
normalizedAxes.reserve(rank);
|
||||
for (int64_t axis = 0; axis < rank; ++axis)
|
||||
normalizedAxes.push_back(axis);
|
||||
}
|
||||
else {
|
||||
normalizedAxes.reserve(axesAttr->size());
|
||||
for (Attribute attr : *axesAttr)
|
||||
normalizedAxes.push_back(normalizeAxis(cast<IntegerAttr>(attr).getInt(), rank));
|
||||
llvm::sort(normalizedAxes);
|
||||
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
||||
}
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> normalizeAxesChecked(std::optional<ArrayAttr> axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes = normalizeAxesImpl(axesAttr, rank);
|
||||
for (int64_t axis : normalizedAxes)
|
||||
if (axis < 0 || axis >= rank)
|
||||
return failure();
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,20 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
int64_t normalizeAxis(int64_t axis, int64_t rank);
|
||||
|
||||
mlir::FailureOr<int64_t> normalizeAxisChecked(int64_t axis, int64_t rank);
|
||||
|
||||
int64_t normalizeIndex(int64_t index, int64_t dimSize);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> normalizeAxesChecked(std::optional<mlir::ArrayAttr> axesAttr, int64_t rank);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,26 +3,103 @@
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "IndexingUtils.hpp"
|
||||
#include "ShapeTilingUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool hasStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(RankedTensorType type) {
|
||||
return type.hasStaticShape() && hasStaticPositiveShape(type.getShape());
|
||||
}
|
||||
|
||||
int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
SmallVector<int64_t> invertPermutation(ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> inversePermutation(permutation.size());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
return inversePermutation;
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<int64_t>> getTransposePermutationChecked(std::optional<ArrayAttr> permAttr, int64_t rank) {
|
||||
SmallVector<int64_t> permutation;
|
||||
if (!permAttr) {
|
||||
permutation.reserve(rank);
|
||||
for (int64_t dim = rank - 1; dim >= 0; --dim)
|
||||
permutation.push_back(dim);
|
||||
return permutation;
|
||||
}
|
||||
|
||||
if (static_cast<int64_t>(permAttr->size()) != rank)
|
||||
return failure();
|
||||
|
||||
permutation.reserve(permAttr->size());
|
||||
SmallVector<bool> seen(rank, false);
|
||||
for (IntegerAttr attr : permAttr->getAsRange<IntegerAttr>()) {
|
||||
int64_t axis = attr.getInt();
|
||||
if (axis < 0 || axis >= rank || seen[axis])
|
||||
return failure();
|
||||
seen[axis] = true;
|
||||
permutation.push_back(axis);
|
||||
}
|
||||
return permutation;
|
||||
}
|
||||
|
||||
Value transposeMaybeInCompute(
|
||||
Value value, RankedTensorType resultType, ArrayRef<int64_t> permutation, PatternRewriter& rewriter, Location loc) {
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, resultType, input, rewriter.getI64ArrayAttr(permutation)).getResult();
|
||||
};
|
||||
return materializeOrComputeUnary(value, resultType, rewriter, loc, buildTranspose);
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getUnitStrides(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getZeroOffsets(PatternRewriter& rewriter, int64_t rank) {
|
||||
return SmallVector<OpFoldResult>(rank, rewriter.getIndexAttr(0));
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> getStaticSizes(PatternRewriter& rewriter, ArrayRef<int64_t> shape) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (int64_t dim : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
return sizes;
|
||||
}
|
||||
|
||||
SmallVector<Value> sliceTensor(
|
||||
const Value& tensorToSlice, size_t axis, int64_t sliceSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
ArrayRef<long> shape = getTensorShape(tensorToSlice);
|
||||
assert("Invalid axis" && axis < shape.size());
|
||||
|
||||
SmallVector<OpFoldResult> strides(shape.size(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> offsets(shape.size(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(shape.size());
|
||||
for (const auto size : shape)
|
||||
sizes.push_back(rewriter.getIndexAttr(size));
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, shape.size());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, shape);
|
||||
sizes[axis] = rewriter.getIndexAttr(sliceSize);
|
||||
|
||||
long length = shape[axis];
|
||||
@@ -44,7 +121,7 @@ SmallVector<Value> sliceTensor(
|
||||
RankedTensorType::get(sliceShape, cast<RankedTensorType>(tensorToSlice.getType()).getElementType());
|
||||
|
||||
Value slice;
|
||||
if (isHostFoldableValue(tensorToSlice)) {
|
||||
if (isCompileTimeComputable(tensorToSlice)) {
|
||||
slice = tensor::ExtractSliceOp::create(rewriter, loc, tensorToSlice, offsets, sizes, strides);
|
||||
}
|
||||
else {
|
||||
@@ -80,47 +157,33 @@ sliceVectorPerCrossbarPerCore(const Value& vectorToSlice, ConversionPatternRewri
|
||||
return slicesPerCore;
|
||||
}
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tileMatrix(
|
||||
Value& matrixToTile, int64_t hSliceSize, int64_t vSliceSize, ConversionPatternRewriter& rewriter, Location& loc) {
|
||||
assert("Not a matrix" && isMatrixShape(getTensorShape(matrixToTile)));
|
||||
Value extractAxisSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, int64_t axis, int64_t offset, int64_t size) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
SmallVector<int64_t> resultShape(sourceType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, sourceType.getElementType(), sourceType.getEncoding());
|
||||
|
||||
DenseMap<HSliceId, DenseMap<CoreId, SmallVector<Value>>> tiles;
|
||||
|
||||
SmallVector<Value> hSlices = sliceTensor(matrixToTile, 1, hSliceSize, rewriter, loc);
|
||||
size_t numHSlices = hSlices.size();
|
||||
for (size_t hSliceId = 0; hSliceId < numHSlices; hSliceId++) {
|
||||
Value hSlice = hSlices[hSliceId];
|
||||
SmallVector<Value> vSlices = sliceTensor(hSlice, 0, vSliceSize, rewriter, loc);
|
||||
for (size_t vSliceId = 0; vSliceId < vSlices.size(); vSliceId++) {
|
||||
size_t coreId = vSliceId / crossbarCountInCore;
|
||||
Value vSlice = vSlices[vSliceId];
|
||||
tiles[hSliceId][coreId].push_back(vSlice);
|
||||
}
|
||||
}
|
||||
return tiles;
|
||||
SmallVector<OpFoldResult> offsets = getZeroOffsets(rewriter, sourceType.getRank());
|
||||
SmallVector<OpFoldResult> sizes = getStaticSizes(rewriter, sourceType.getShape());
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, resultType, source, offsets, sizes, getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
Value broadcastToVector(Value scalarToBroadcast, int64_t length, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto oldType = cast<RankedTensorType>(scalarToBroadcast.getType());
|
||||
Type elementType = oldType.getElementType();
|
||||
int64_t shape[2] = {1, length};
|
||||
Type type = oldType.cloneWith(ArrayRef(shape), elementType);
|
||||
|
||||
auto buildBroadcast = [&](Value input) -> Value {
|
||||
auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0).getResult();
|
||||
SmallVector<Value> index(oldType.getRank(), zero);
|
||||
auto elementValue = tensor::ExtractOp::create(rewriter, loc, input, index).getResult();
|
||||
return tensor::SplatOp::create(rewriter, loc, type, elementValue);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(scalarToBroadcast))
|
||||
return buildBroadcast(scalarToBroadcast);
|
||||
|
||||
auto broadcastCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {type}, {}, ValueRange {scalarToBroadcast}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildBroadcast(input));
|
||||
});
|
||||
return broadcastCompute.getResult(0);
|
||||
Value insertStaticSlice(
|
||||
PatternRewriter& rewriter, Location loc, Value source, Value dest, ArrayRef<OpFoldResult> offsets) {
|
||||
auto sourceType = cast<RankedTensorType>(source.getType());
|
||||
return tensor::InsertSliceOp::create(rewriter,
|
||||
loc,
|
||||
source,
|
||||
dest,
|
||||
offsets,
|
||||
getStaticSizes(rewriter, sourceType.getShape()),
|
||||
getUnitStrides(rewriter, sourceType.getRank()))
|
||||
.getResult();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
@@ -11,46 +12,12 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageChannel(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(1);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getImageN(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelWidth(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(2);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getKernelHeight(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(3);
|
||||
}
|
||||
|
||||
template <class ShapedType>
|
||||
inline auto getFilterCount(const ShapedType& shapedType) {
|
||||
return shapedType.getDimSize(0);
|
||||
}
|
||||
|
||||
using HSliceId = size_t;
|
||||
using CoreId = size_t;
|
||||
|
||||
@@ -87,17 +54,6 @@ bool isHVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[0] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
bool isVVectorShape(mlir::ArrayRef<T> shape) {
|
||||
return shape.size() == 2 && shape[1] == 1;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
T getVectorLength(mlir::ArrayRef<T> shape) {
|
||||
assert(isVectorShape(shape));
|
||||
return shape[0] != 1 ? shape[0] : shape[1];
|
||||
}
|
||||
|
||||
inline auto getTensorShape(mlir::Value tensor) {
|
||||
return mlir::cast<mlir::RankedTensorType>(tensor.getType()).getShape();
|
||||
}
|
||||
@@ -109,6 +65,31 @@ inline bool haveSameStaticShape(mlir::Value lhs, mlir::Value rhs) {
|
||||
&& lhsType.getShape() == rhsType.getShape();
|
||||
}
|
||||
|
||||
bool hasStaticPositiveShape(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
bool hasStaticPositiveShape(mlir::RankedTensorType type);
|
||||
|
||||
int64_t getStaticShapeElementCount(mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
llvm::SmallVector<int64_t> permuteShape(mlir::ArrayRef<int64_t> shape, mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
llvm::SmallVector<int64_t> invertPermutation(mlir::ArrayRef<int64_t> permutation);
|
||||
|
||||
mlir::FailureOr<llvm::SmallVector<int64_t>> getTransposePermutationChecked(std::optional<mlir::ArrayAttr> permAttr,
|
||||
int64_t rank);
|
||||
|
||||
mlir::Value transposeMaybeInCompute(mlir::Value value,
|
||||
mlir::RankedTensorType resultType,
|
||||
mlir::ArrayRef<int64_t> permutation,
|
||||
mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getUnitStrides(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getZeroOffsets(mlir::PatternRewriter& rewriter, int64_t rank);
|
||||
|
||||
llvm::SmallVector<mlir::OpFoldResult> getStaticSizes(mlir::PatternRewriter& rewriter, mlir::ArrayRef<int64_t> shape);
|
||||
|
||||
/// Slices a statically shaped tensor along one axis into contiguous pieces of
|
||||
/// at most `sliceSize` elements.
|
||||
llvm::SmallVector<mlir::Value> sliceTensor(const mlir::Value& tensorToSlice,
|
||||
@@ -127,18 +108,13 @@ llvm::SmallVector<mlir::Value> sliceVector(const mlir::Value& vectorToSlice,
|
||||
llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>> sliceVectorPerCrossbarPerCore(
|
||||
const mlir::Value& vectorToSlice, mlir::ConversionPatternRewriter& rewriter, mlir::Location loc);
|
||||
|
||||
/// Tiles a matrix first across output columns and then across input rows so it
|
||||
/// can be assigned to crossbars grouped by core.
|
||||
llvm::DenseMap<HSliceId, llvm::DenseMap<CoreId, llvm::SmallVector<mlir::Value>>>
|
||||
tileMatrix(mlir::Value& matrixToTile,
|
||||
int64_t hSliceSize,
|
||||
int64_t vSliceSize,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location& loc);
|
||||
mlir::Value extractAxisSlice(
|
||||
mlir::PatternRewriter& rewriter, mlir::Location loc, mlir::Value source, int64_t axis, int64_t offset, int64_t size);
|
||||
|
||||
mlir::Value broadcastToVector(mlir::Value scalarToBroadcast,
|
||||
int64_t length,
|
||||
mlir::ConversionPatternRewriter& rewriter,
|
||||
mlir::Location loc);
|
||||
mlir::Value insertStaticSlice(mlir::PatternRewriter& rewriter,
|
||||
mlir::Location loc,
|
||||
mlir::Value source,
|
||||
mlir::Value dest,
|
||||
llvm::ArrayRef<mlir::OpFoldResult> offsets);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@@ -43,8 +44,8 @@ bool isWeightLikeComputeOperand(Value value) {
|
||||
value = collapseShapeOp.getSrc();
|
||||
continue;
|
||||
}
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
value = transposeOp.getData();
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||
value = transposeOp.getInput();
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -80,7 +81,7 @@ FailureOr<Value> materializeWeightLikeValueInBlock(Value value, IRRewriter& rewr
|
||||
return referencedValue.getResult();
|
||||
}
|
||||
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(definingOp))
|
||||
if (!isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(definingOp))
|
||||
return failure();
|
||||
|
||||
IRMapping localMapper;
|
||||
|
||||
+79
-36
@@ -1,4 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
@@ -8,7 +9,11 @@
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include <utility>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ConstantUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -23,8 +28,7 @@ static bool hasStaticUnitStrides(tensor::ExtractSliceOp extractSliceOp) {
|
||||
}
|
||||
|
||||
static bool hasConstantIndices(tensor::ExtractOp extractOp) {
|
||||
return llvm::all_of(extractOp.getIndices(),
|
||||
[](Value index) { return isa_and_nonnull<arith::ConstantIndexOp>(index.getDefiningOp()); });
|
||||
return llvm::all_of(extractOp.getIndices(), [](Value index) { return matchConstantIndexValue(index).has_value(); });
|
||||
}
|
||||
|
||||
static bool isStaticTensorResult(Operation* op) {
|
||||
@@ -34,13 +38,6 @@ static bool isStaticTensorResult(Operation* op) {
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
|
||||
SmallVector<int64_t> strides(shape.size(), 1);
|
||||
for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
|
||||
strides[dim] = strides[dim + 1] * shape[dim + 1];
|
||||
return strides;
|
||||
}
|
||||
|
||||
static FailureOr<DenseElementsAttr> transposeDenseElements(DenseElementsAttr denseAttr, ArrayRef<int64_t> perms) {
|
||||
auto tensorType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!tensorType)
|
||||
@@ -145,7 +142,7 @@ static DenseElementsAttr getDirectDenseConstantAttr(Value value) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
static DenseElementsAttr getHostConstantDenseElementsAttrImpl(Value value, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp || !visited.insert(definingOp).second)
|
||||
return nullptr;
|
||||
@@ -156,7 +153,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
return denseAttr;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getData(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
@@ -168,8 +165,18 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(definingOp)) {
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(transposeOp.getInput(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
|
||||
SmallVector<int64_t> perm(transposeOp.getPermutation().begin(), transposeOp.getPermutation().end());
|
||||
auto transposedAttr = transposeDenseElements(inputAttr, perm);
|
||||
return succeeded(transposedAttr) ? *transposedAttr : nullptr;
|
||||
}
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(collapseShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(collapseShapeOp.getType()));
|
||||
@@ -177,7 +184,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
}
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(expandShapeOp.getSrc(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto reshapedAttr = reshapeDenseElements(inputAttr, cast<RankedTensorType>(expandShapeOp.getType()));
|
||||
@@ -185,7 +192,7 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
}
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(definingOp)) {
|
||||
auto inputAttr = getHostFoldableDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
auto inputAttr = getHostConstantDenseElementsAttrImpl(extractSliceOp.getSource(), visited);
|
||||
if (!inputAttr)
|
||||
return nullptr;
|
||||
auto slicedAttr = extractSliceDenseElements(inputAttr, extractSliceOp);
|
||||
@@ -195,62 +202,98 @@ static DenseElementsAttr getHostFoldableDenseElementsAttrImpl(Value value, llvm:
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static bool isHostFoldableOpImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited) {
|
||||
if (!op || !visited.insert(op).second)
|
||||
return false;
|
||||
static std::optional<CompileTimeSource>
|
||||
getCompileTimeSourceImpl(Operation* op, llvm::SmallPtrSetImpl<Operation*>& visited, size_t chainLength = 0) {
|
||||
if (!op)
|
||||
return std::nullopt;
|
||||
|
||||
if (!visited.insert(op).second)
|
||||
return {
|
||||
{op, chainLength}
|
||||
};
|
||||
|
||||
if (isa<arith::ConstantOp, ONNXConstantOp, ONNXNoneOp>(op))
|
||||
return true;
|
||||
return {
|
||||
{op, chainLength}
|
||||
};
|
||||
|
||||
chainLength += 1;
|
||||
|
||||
if (auto extractOp = dyn_cast<tensor::ExtractOp>(op))
|
||||
return hasConstantIndices(extractOp) && isHostFoldableValue(extractOp.getTensor());
|
||||
return hasConstantIndices(extractOp)
|
||||
? getCompileTimeSourceImpl(extractOp.getTensor().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (!isStaticTensorResult(op))
|
||||
return false;
|
||||
return std::nullopt;
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op))
|
||||
return isHostFoldableValue(transposeOp.getData());
|
||||
return getCompileTimeSourceImpl(transposeOp.getData().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op))
|
||||
return getCompileTimeSourceImpl(transposeOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op))
|
||||
return isHostFoldableValue(collapseShapeOp.getSrc());
|
||||
return getCompileTimeSourceImpl(collapseShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op))
|
||||
return isHostFoldableValue(expandShapeOp.getSrc());
|
||||
return getCompileTimeSourceImpl(expandShapeOp.getSrc().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op))
|
||||
return hasStaticUnitStrides(extractSliceOp) && isHostFoldableValue(extractSliceOp.getSource());
|
||||
return hasStaticUnitStrides(extractSliceOp)
|
||||
? getCompileTimeSourceImpl(extractSliceOp.getSource().getDefiningOp(), visited, chainLength)
|
||||
: std::nullopt;
|
||||
|
||||
if (auto splatOp = dyn_cast<tensor::SplatOp>(op))
|
||||
return isHostFoldableValue(splatOp.getInput());
|
||||
return getCompileTimeSourceImpl(splatOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto extractRowsOp = dyn_cast<spatial::SpatExtractRowsOp>(op))
|
||||
return isHostFoldableValue(extractRowsOp.getInput());
|
||||
return getCompileTimeSourceImpl(extractRowsOp.getInput().getDefiningOp(), visited, chainLength);
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op))
|
||||
return llvm::all_of(concatOp.getInputs(), isHostFoldableValue);
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
std::optional<CompileTimeSource> res = {};
|
||||
for (auto operandValue : concatOp.getOperands()) {
|
||||
auto partialRes = getCompileTimeSourceImpl(operandValue.getDefiningOp(), visited, chainLength);
|
||||
if (!partialRes)
|
||||
return std::nullopt;
|
||||
|
||||
return false;
|
||||
if (!res) {
|
||||
res = partialRes;
|
||||
continue;
|
||||
}
|
||||
if (res->chainLength < partialRes->chainLength)
|
||||
res = partialRes;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool isHostFoldableValue(Value value) {
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getCompileTimeSourceImpl(op, visited);
|
||||
}
|
||||
|
||||
bool isCompileTimeComputable(Value value) {
|
||||
auto* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return false;
|
||||
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(definingOp, visited);
|
||||
return getCompileTimeSourceImpl(definingOp, visited).has_value();
|
||||
}
|
||||
|
||||
bool isHostFoldableOp(Operation* op) {
|
||||
bool isCompileTimeOp(Operation* op) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return isHostFoldableOpImpl(op, visited);
|
||||
return getCompileTimeSourceImpl(op, visited).has_value();
|
||||
}
|
||||
|
||||
DenseElementsAttr getHostFoldableDenseElementsAttr(Value value) {
|
||||
DenseElementsAttr getHostConstDenseElementsAttr(Value value) {
|
||||
llvm::SmallPtrSet<Operation*, 8> visited;
|
||||
return getHostFoldableDenseElementsAttrImpl(value, visited);
|
||||
return getHostConstantDenseElementsAttrImpl(value, visited);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,22 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CompileTimeSource {
|
||||
mlir::Operation* source;
|
||||
size_t chainLength;
|
||||
};
|
||||
|
||||
std::optional<CompileTimeSource> getCompileTimeSource(mlir::Operation* op);
|
||||
|
||||
bool isCompileTimeComputable(mlir::Value value);
|
||||
|
||||
bool isCompileTimeOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostConstDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,15 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool isHostFoldableValue(mlir::Value value);
|
||||
|
||||
bool isHostFoldableOp(mlir::Operation* op);
|
||||
|
||||
mlir::DenseElementsAttr getHostFoldableDenseElementsAttr(mlir::Value value);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,34 +0,0 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifyONNXToSpatialHostLegality(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getFunctionBody().front()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isHostFoldableOp(&op))
|
||||
continue;
|
||||
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError("non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside "
|
||||
"spat.compute");
|
||||
});
|
||||
}
|
||||
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial host legality failures");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifyONNXToSpatialHostLegality(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,5 +1,7 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
@@ -12,13 +14,10 @@
|
||||
|
||||
#include "Common/Common.hpp"
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostLegality.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Compiler/CompilerOptions.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -44,7 +43,8 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
IRMapping mapper;
|
||||
SmallVector<spatial::SpatCompute> computes(funcOp.getOps<spatial::SpatCompute>());
|
||||
if (!computes.empty())
|
||||
SmallVector<spatial::SpatComputeBatch> computeBatches(funcOp.getOps<spatial::SpatComputeBatch>());
|
||||
if (!computes.empty() || !computeBatches.empty())
|
||||
return;
|
||||
|
||||
auto returnOp = cast<func::ReturnOp>(funcOp.getFunctionBody().front().getTerminator());
|
||||
@@ -85,30 +85,6 @@ static void populateEmptyFunction(func::FuncOp funcOp) {
|
||||
returnOp.setOperand(index, computeResult);
|
||||
}
|
||||
|
||||
static void wrapTopLevelRuntimeTransposes(func::FuncOp funcOp) {
|
||||
IRRewriter rewriter(funcOp.getContext());
|
||||
Block& entryBlock = funcOp.getFunctionBody().front();
|
||||
|
||||
for (Operation& op : llvm::make_early_inc_range(entryBlock)) {
|
||||
auto transposeOp = dyn_cast<ONNXTransposeOp>(&op);
|
||||
if (!transposeOp || isHostFoldableOp(transposeOp))
|
||||
continue;
|
||||
|
||||
// Transpose stays globally legal because constant/view-only cases are
|
||||
// allowed on the host. Any residual runtime transpose must be sunk into
|
||||
// spat.compute before the host legality check.
|
||||
auto resultType = transposeOp.getResult().getType();
|
||||
rewriter.setInsertionPoint(transposeOp);
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, transposeOp.getLoc(), TypeRange {resultType}, {}, ValueRange {transposeOp.getData()}, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, transposeOp.getLoc(), resultType, input, transposeOp.getPermAttr());
|
||||
spatial::SpatYieldOp::create(rewriter, transposeOp.getLoc(), transposed);
|
||||
});
|
||||
rewriter.replaceOp(transposeOp, computeOp.getResult(0));
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXToSpatialPass::runOnOperation() {
|
||||
ModuleOp moduleOp = getOperation();
|
||||
MLIRContext* ctx = &getContext();
|
||||
@@ -116,7 +92,9 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
ConversionTarget preTarget(*ctx);
|
||||
preTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
linalg::LinalgDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
preTarget.addIllegalOp<ONNXConstantOp, ONNXFlattenOp>();
|
||||
@@ -154,10 +132,13 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
linalg::LinalgDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
target.addIllegalOp<ONNXMatMulOp>();
|
||||
target.addIllegalOp<ONNXTransposeOp>();
|
||||
target.addIllegalOp<ONNXAddOp>();
|
||||
target.addIllegalOp<ONNXDivOp>();
|
||||
target.addIllegalOp<ONNXMulOp>();
|
||||
@@ -184,22 +165,18 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
RewritePatternSet transposePatterns(ctx);
|
||||
populateTransposePatterns(transposePatterns, ctx);
|
||||
walkAndApplyPatterns(moduleOp, std::move(transposePatterns));
|
||||
|
||||
ConversionTarget earlyPostTarget(*ctx);
|
||||
earlyPostTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
linalg::LinalgDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
earlyPostTarget.addDynamicallyLegalOp<spatial::SpatComputeBatch>(
|
||||
[](spatial::SpatComputeBatch batchOp) { return !requiresEarlyPostRewrite(batchOp); });
|
||||
|
||||
RewritePatternSet earlyPostPatterns(ctx);
|
||||
populateEarlyPostPatterns(earlyPostPatterns, ctx);
|
||||
if (failed(applyPartialConversion(*entryFunc, earlyPostTarget, std::move(earlyPostPatterns)))) {
|
||||
moduleOp.emitError("failed to normalize single-lane spat.compute_batch ops before core assignment checks");
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
||||
PassManager cleanupPM(ctx);
|
||||
cleanupPM.addPass(createCanonicalizerPass());
|
||||
@@ -211,7 +188,9 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
ConversionTarget postTarget(*ctx);
|
||||
postTarget.addLegalDialect<spatial::SpatialDialect,
|
||||
ONNXDialect,
|
||||
linalg::LinalgDialect,
|
||||
tensor::TensorDialect,
|
||||
affine::AffineDialect,
|
||||
arith::ArithDialect,
|
||||
scf::SCFDialect>();
|
||||
postTarget.addDynamicallyLegalOp<spatial::SpatCompute>(
|
||||
@@ -227,9 +206,7 @@ void ONNXToSpatialPass::runOnOperation() {
|
||||
return;
|
||||
}
|
||||
|
||||
wrapTopLevelRuntimeTransposes(*entryFunc);
|
||||
|
||||
if (failed(verifyONNXToSpatialHostLegality(*entryFunc))) {
|
||||
if (failed(verifyONNXToSpatial(*entryFunc))) {
|
||||
moduleOp.emitError("ONNX-to-Spatial host legality verification failed");
|
||||
signalPassFailure();
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
#include "Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatialVerifier.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
void checkWeightUseChains(func::FuncOp func, pim::CappedDiagnosticReporter& diagnostics) {
|
||||
func.walk([&](Operation* op) {
|
||||
if (!hasWeightAlways(op))
|
||||
return;
|
||||
|
||||
for (Value result : op->getResults()) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(result))
|
||||
continue;
|
||||
|
||||
diagnostics.report(op, [&](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"weight-marked values may only flow through static view/slice helper chains into Spatial VMM weights");
|
||||
});
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Region* getParentRegion(Value value) {
|
||||
if (auto blockArg = dyn_cast<BlockArgument>(value))
|
||||
return blockArg.getOwner()->getParent();
|
||||
if (Operation* definingOp = value.getDefiningOp())
|
||||
return definingOp->getParentRegion();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool isDefinedInsideRegion(Value value, Region& region) {
|
||||
Region* parentRegion = getParentRegion(value);
|
||||
return parentRegion && (®ion == parentRegion || region.isAncestor(parentRegion));
|
||||
}
|
||||
|
||||
bool isLegalHostBackedValue(Value value) {
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return isa<BlockArgument>(value);
|
||||
|
||||
if (isa<spatial::SpatChannelReceiveOp>(definingOp))
|
||||
return false;
|
||||
|
||||
return definingOp->getDialect()->getNamespace() != "spat";
|
||||
}
|
||||
|
||||
LogicalResult verifyComputeLikeInputs(Operation* computeLikeOp,
|
||||
ValueRange inputs,
|
||||
bool allowChannelReceiveInputs,
|
||||
StringRef kind,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
for (auto [inputIndex, input] : llvm::enumerate(inputs)) {
|
||||
unsigned currentInputIndex = inputIndex;
|
||||
Operation* definingOp = input.getDefiningOp();
|
||||
if (allowChannelReceiveInputs && isa_and_nonnull<spatial::SpatChannelReceiveOp>(definingOp))
|
||||
continue;
|
||||
if (isLegalHostBackedValue(input))
|
||||
continue;
|
||||
|
||||
diagnostics.report(computeLikeOp, [&](Operation* illegalOp) {
|
||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError()
|
||||
<< kind << " input #" << currentInputIndex
|
||||
<< (allowChannelReceiveInputs ? " must come from the host or an explicit "
|
||||
"spat.channel_receive"
|
||||
: " must come from the host");
|
||||
if (definingOp)
|
||||
diagnostic.attachNote(definingOp->getLoc()) << "illegal Spatial producer is " << definingOp->getName();
|
||||
});
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
void verifyNoExternalTensorCaptures(Operation* ownerOp,
|
||||
Region& region,
|
||||
StringRef kind,
|
||||
pim::CappedDiagnosticReporter& diagnostics) {
|
||||
region.walk([&](Operation* op) {
|
||||
for (OpOperand& operand : op->getOpOperands()) {
|
||||
Value value = operand.get();
|
||||
if (!isa<TensorType>(value.getType()))
|
||||
continue;
|
||||
if (isDefinedInsideRegion(value, region) || isa<BlockArgument>(value))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = value.getDefiningOp();
|
||||
if (definingOp && definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
|
||||
diagnostics.report(ownerOp, [&](Operation* illegalOp) {
|
||||
InFlightDiagnostic diagnostic = illegalOp->emitOpError() << kind << " body may not capture external tensor "
|
||||
<< "values";
|
||||
diagnostic.attachNote(op->getLoc())
|
||||
<< "tensor operand #" << operand.getOperandNumber() << " is defined outside the compute body by "
|
||||
<< (definingOp ? definingOp->getName().getStringRef() : StringRef("<block argument>"));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult verifyONNXToSpatial(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (Operation& op : funcOp.getOps()) {
|
||||
if (isa<func::ReturnOp, spatial::SpatCompute, spatial::SpatComputeBatch>(&op))
|
||||
continue;
|
||||
if (isCompileTimeOp(&op))
|
||||
continue;
|
||||
|
||||
diagnostics.report(&op, [](Operation* illegalOp) {
|
||||
illegalOp->emitOpError(
|
||||
"non-foldable top-level runtime op remains after ONNX-to-Spatial; lower it inside spat.compute");
|
||||
});
|
||||
}
|
||||
checkWeightUseChains(funcOp, diagnostics);
|
||||
diagnostics.emitSuppressedSummary(funcOp, "ONNX-to-Spatial verification failed");
|
||||
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
LogicalResult verifySpatialCommunicationInvariants(func::FuncOp funcOp) {
|
||||
pim::CappedDiagnosticReporter diagnostics;
|
||||
|
||||
for (auto computeOp : funcOp.getOps<spatial::SpatCompute>()) {
|
||||
(void) verifyComputeLikeInputs(
|
||||
computeOp.getOperation(), computeOp.getInputs(), /*allowChannelReceiveInputs=*/true, "spat.compute", diagnostics);
|
||||
verifyNoExternalTensorCaptures(computeOp.getOperation(), computeOp.getBody(), "spat.compute", diagnostics);
|
||||
}
|
||||
|
||||
for (auto computeBatchOp : funcOp.getOps<spatial::SpatComputeBatch>()) {
|
||||
(void) verifyComputeLikeInputs(computeBatchOp.getOperation(),
|
||||
computeBatchOp.getInputs(),
|
||||
/*allowChannelReceiveInputs=*/false,
|
||||
"spat.compute_batch",
|
||||
diagnostics);
|
||||
verifyNoExternalTensorCaptures(
|
||||
computeBatchOp.getOperation(), computeBatchOp.getBody(), "spat.compute_batch", diagnostics);
|
||||
}
|
||||
|
||||
diagnostics.emitSuppressedSummary(funcOp, "Spatial communication invariant verification failed");
|
||||
return success(!diagnostics.hasFailure());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifyONNXToSpatial(mlir::func::FuncOp funcOp);
|
||||
mlir::LogicalResult verifySpatialCommunicationInvariants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+9
-9
@@ -1,19 +1,14 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<removeLRN>(ctx);
|
||||
void populatePrePatterns(RewritePatternSet& patterns, MLIRContext* ctx) { populateGeneratedPrePatterns(patterns, ctx); }
|
||||
|
||||
void populateConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
populateGeneratedConversionPatterns(patterns, ctx);
|
||||
populateElementwisePatterns(patterns, ctx);
|
||||
populateGemmPatterns(patterns, ctx);
|
||||
populateConvPatterns(patterns, ctx);
|
||||
@@ -27,6 +22,11 @@ void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRCon
|
||||
populateResizePatterns(patterns, ctx);
|
||||
populateReshapePatterns(patterns, ctx);
|
||||
populateSplitPatterns(patterns, ctx);
|
||||
populateTransposePatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
populateWeightPromotionPatterns(patterns, ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+14
-13
@@ -1,38 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateGeneratedConversionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateWeightPromotionPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateConvPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateElementwisePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateGemmPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateMatMulRewritePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePoolPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReduceMeanPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReluPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateSigmoidPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateSoftmaxPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateConcatPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateGatherPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateResizePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateReshapePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populateSplitPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
void populateTransposePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -0,0 +1,18 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ONNXToSpatial.hpp.inc"
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateGeneratedConversionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<removeLRN>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "src/Accelerators/PIM/Common/Support/Diagnostics.hpp"
|
||||
#include "src/Accelerators/PIM/Compiler/PimCompilerOptions.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -28,8 +28,6 @@ struct ConvToGemm : OpConversionPattern<ONNXConvOp> {
|
||||
ConversionPatternRewriter& rewriter) const override;
|
||||
};
|
||||
|
||||
static int64_t getI64FromArrayAttr(ArrayAttr arr, size_t idx) { return cast<IntegerAttr>(arr[idx]).getInt(); }
|
||||
|
||||
static Value expandBiasIfNeeded(Value bias, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto biasType = cast<RankedTensorType>(bias.getType());
|
||||
if (biasType.getRank() != 1)
|
||||
@@ -53,23 +51,100 @@ static Value createPaddedRows(Value tensorValue,
|
||||
if (tensorType.getDimSize(0) == paddedRows)
|
||||
return tensorValue;
|
||||
|
||||
auto paddedType = RankedTensorType::get({paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType());
|
||||
auto paddedType = RankedTensorType::get(
|
||||
{paddedRows, tensorType.getDimSize(1)}, tensorType.getElementType(), tensorType.getEncoding());
|
||||
SmallVector<OpFoldResult> lowPads = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> highPads = {rewriter.getIndexAttr(paddedRows - tensorType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(0)};
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, paddedType, tensorValue, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int i = 0; i < 2; i++)
|
||||
for (int i = 0; i < 2; ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(
|
||||
rewriter, loc, tensorType.getElementType(), rewriter.getZeroAttr(tensorType.getElementType()));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(tensorType.getElementType()), tensorType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
static Value packRowsForParallelGemm(
|
||||
Value rows, RankedTensorType rowsType, int64_t packFactor, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (packFactor == 1)
|
||||
return rows;
|
||||
|
||||
const int64_t packedNumRows = ceilIntegerDivide(rowsType.getDimSize(0), packFactor);
|
||||
const int64_t paddedNumRows = packedNumRows * packFactor;
|
||||
const int64_t rowWidth = rowsType.getDimSize(1);
|
||||
auto groupedType =
|
||||
RankedTensorType::get({packedNumRows, packFactor, rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
|
||||
auto packedType =
|
||||
RankedTensorType::get({packedNumRows, packFactor * rowWidth}, rowsType.getElementType(), rowsType.getEncoding());
|
||||
|
||||
Value paddedRows = createPaddedRows(rows, rowsType, paddedNumRows, rewriter, loc);
|
||||
Value groupedRows = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedRows,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedRows,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
|
||||
static Value unpackRowsFromParallelGemm(Value packedRows,
|
||||
RankedTensorType packedRowsType,
|
||||
int64_t unpackedRows,
|
||||
int64_t rowWidth,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
if (packFactor == 1)
|
||||
return packedRows;
|
||||
|
||||
const int64_t packedNumRows = packedRowsType.getDimSize(0);
|
||||
const int64_t paddedNumRows = packedNumRows * packFactor;
|
||||
auto expandedType = RankedTensorType::get(
|
||||
{packedNumRows, packFactor, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||
auto paddedType =
|
||||
RankedTensorType::get({paddedNumRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||
auto unpackedType =
|
||||
RankedTensorType::get({unpackedRows, rowWidth}, packedRowsType.getElementType(), packedRowsType.getEncoding());
|
||||
|
||||
Value expandedRows = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
packedRows,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
Value paddedRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
paddedType,
|
||||
expandedRows,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
if (paddedNumRows == unpackedRows)
|
||||
return paddedRows;
|
||||
|
||||
SmallVector<OpFoldResult> offsets {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(unpackedRows), rewriter.getIndexAttr(rowWidth)};
|
||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, unpackedType, paddedRows, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
Value wTrans,
|
||||
RankedTensorType wType,
|
||||
@@ -108,7 +183,30 @@ static Value buildPackedWeight(DenseElementsAttr wDenseAttr,
|
||||
}
|
||||
|
||||
auto packedAttr = DenseElementsAttr::get(packedWeightType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedWeightType, packedAttr);
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedAttr, packedWeightType);
|
||||
}
|
||||
|
||||
static Value createConvWeightMatrix(
|
||||
Value w, RankedTensorType wFlatType, RankedTensorType wTransType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto buildWeightMatrix = [&](Value weight) -> Value {
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
weight,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
return ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0})).getResult();
|
||||
};
|
||||
|
||||
if (isCompileTimeComputable(w))
|
||||
return buildWeightMatrix(w);
|
||||
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {wTransType}, {}, ValueRange {w}, [&](Value weight) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildWeightMatrix(weight));
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildPackedBias(bool hasBias,
|
||||
@@ -134,7 +232,7 @@ static Value buildPackedBias(bool hasBias,
|
||||
|
||||
auto packedBiasType = RankedTensorType::get({1, packFactor * numChannelsOut}, outType.getElementType());
|
||||
auto packedBiasAttr = DenseElementsAttr::get(packedBiasType, packedValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, packedBiasType, packedBiasAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), packedBiasAttr, packedBiasType);
|
||||
}
|
||||
|
||||
static Value createIm2colRowComputes(Value x,
|
||||
@@ -165,7 +263,6 @@ static Value createIm2colRowComputes(Value x,
|
||||
Location loc) {
|
||||
auto elemType = xType.getElementType();
|
||||
constexpr size_t numInputs = 1;
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
auto im2colComputeOp =
|
||||
createSpatCompute<numInputs>(rewriter, loc, TypeRange {gemmInputRowsType}, {}, x, [&](Value xArg) {
|
||||
Value paddedInput = xArg;
|
||||
@@ -190,8 +287,8 @@ static Value createIm2colRowComputes(Value x,
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getFloatAttr(elemType, 0.0));
|
||||
tensor::YieldOp::create(rewriter, loc, zero.getResult());
|
||||
auto zero = getOrCreateConstant(rewriter, padOp.getOperation(), rewriter.getFloatAttr(elemType, 0.0), elemType);
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
paddedInput = padOp.getResult();
|
||||
}
|
||||
@@ -199,13 +296,14 @@ static Value createIm2colRowComputes(Value x,
|
||||
// Build im2col [numPatches, patchSize] incrementally to keep the IR small
|
||||
// until the late PIM unrolling step.
|
||||
Value im2colInit = tensor::EmptyOp::create(rewriter, loc, im2colType.getShape(), elemType);
|
||||
auto c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
auto cNumPatches = arith::ConstantIndexOp::create(rewriter, loc, numPatches);
|
||||
auto cNumPatchesPerBatch = arith::ConstantIndexOp::create(rewriter, loc, numPatchesPerBatch);
|
||||
auto cOutWidth = arith::ConstantIndexOp::create(rewriter, loc, outWidth);
|
||||
auto cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
auto cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
auto c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
auto c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
auto cNumPatches = getOrCreateIndexConstant(rewriter, anchorOp, numPatches);
|
||||
auto cNumPatchesPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, numPatchesPerBatch);
|
||||
auto cOutWidth = getOrCreateIndexConstant(rewriter, anchorOp, outWidth);
|
||||
auto cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
auto cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto im2colLoop = scf::ForOp::create(rewriter, loc, c0, cNumPatches, c1, ValueRange {im2colInit});
|
||||
rewriter.setInsertionPointToStart(im2colLoop.getBody());
|
||||
@@ -252,28 +350,8 @@ static Value createIm2colRowComputes(Value x,
|
||||
Value im2col = im2colLoop.getResult(0);
|
||||
|
||||
Value gemmInputRows = im2col;
|
||||
if (packFactor != 1) {
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto groupedType = RankedTensorType::get({packedNumRows, packFactor, patchSize}, elemType);
|
||||
auto packedType = RankedTensorType::get({packedNumRows, packFactor * patchSize}, elemType);
|
||||
Value paddedIm2col = createPaddedRows(im2col, im2colType, paddedNumPatches, rewriter, loc);
|
||||
Value groupedIm2col = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
groupedType,
|
||||
paddedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
gemmInputRows = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
packedType,
|
||||
groupedIm2col,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
}
|
||||
if (packFactor != 1)
|
||||
gemmInputRows = packRowsForParallelGemm(im2col, im2colType, packFactor, rewriter, loc);
|
||||
|
||||
spatial::SpatYieldOp::create(rewriter, loc, gemmInputRows);
|
||||
});
|
||||
@@ -291,41 +369,20 @@ static Value createCollectedConvOutput(ValueRange gemmRows,
|
||||
int64_t packFactor,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t packedNumRows = ceilIntegerDivide(numPatches, packFactor);
|
||||
const int64_t paddedNumPatches = packedNumRows * packFactor;
|
||||
auto collectComputeOp = createSpatCompute(rewriter, loc, convType, {}, gemmRows, [&](ValueRange gemmRowArgs) {
|
||||
Value gemmOut;
|
||||
if (packFactor == 1) {
|
||||
gemmOut = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
}
|
||||
else {
|
||||
auto expandedType = RankedTensorType::get({packedNumRows, packFactor, numChannelsOut}, outType.getElementType());
|
||||
auto paddedType = RankedTensorType::get({paddedNumPatches, numChannelsOut}, outType.getElementType());
|
||||
Value packedOutput = createSpatConcat(rewriter, loc, /*axis=*/0, gemmRowArgs);
|
||||
Value expandedOutput = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
expandedType,
|
||||
packedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
Value paddedOutput = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
paddedType,
|
||||
expandedOutput,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
|
||||
gemmOut = paddedOutput;
|
||||
if (paddedNumPatches != numPatches) {
|
||||
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(numPatches), rewriter.getIndexAttr(numChannelsOut)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
gemmOut = tensor::ExtractSliceOp::create(rewriter, loc, gemmOutType, paddedOutput, offsets, sizes, strides);
|
||||
}
|
||||
gemmOut = unpackRowsFromParallelGemm(packedOutput,
|
||||
cast<RankedTensorType>(packedOutput.getType()),
|
||||
numPatches,
|
||||
numChannelsOut,
|
||||
packFactor,
|
||||
rewriter,
|
||||
loc);
|
||||
}
|
||||
|
||||
// Restore to NCHW layout:
|
||||
@@ -391,19 +448,11 @@ static Value lowerSingleConvGroup(Value x,
|
||||
const int64_t xbarSize = static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t wMaxDim = std::max(patchSize, numChannelsOut);
|
||||
const int64_t maxParallelPixels = std::max<int64_t>(1, xbarSize / wMaxDim);
|
||||
auto wDenseAttr = getHostFoldableDenseElementsAttr(w);
|
||||
auto wDenseAttr = getHostConstDenseElementsAttr(w);
|
||||
|
||||
// Prepare weight matrix W for crossbar storage:
|
||||
// W: [numChannelsOut, numChannelsIn, wHeight, wWidth] -> [numChannelsOut, patchSize] -> [patchSize, numChannelsOut]
|
||||
Value wFlat = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
wFlatType,
|
||||
w,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2, 3}
|
||||
});
|
||||
Value wTrans = ONNXTransposeOp::create(rewriter, loc, wTransType, wFlat, rewriter.getI64ArrayAttr({1, 0}));
|
||||
Value wTrans = createConvWeightMatrix(w, wFlatType, wTransType, rewriter, loc);
|
||||
|
||||
// Pass bias through directly; Gemm handles rank-1 C canonicalization.
|
||||
bool hasB = !isa<ONNXNoneOp>(b.getDefiningOp());
|
||||
@@ -412,7 +461,7 @@ static Value lowerSingleConvGroup(Value x,
|
||||
DenseElementsAttr biasDenseAttr;
|
||||
if (hasB) {
|
||||
gemmBias = b;
|
||||
biasDenseAttr = getHostFoldableDenseElementsAttr(b);
|
||||
biasDenseAttr = getHostConstDenseElementsAttr(b);
|
||||
biasMatrix = expandBiasIfNeeded(b, rewriter, loc);
|
||||
}
|
||||
const bool canPackWeightsAsConstants = static_cast<bool>(wDenseAttr);
|
||||
@@ -597,10 +646,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
return failure();
|
||||
}
|
||||
|
||||
const int64_t strideHeight = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 0) : 1;
|
||||
const int64_t strideWidth = stridesAttr ? getI64FromArrayAttr(*stridesAttr, 1) : 1;
|
||||
const int64_t dilationHeight = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 0) : 1;
|
||||
const int64_t dilationWidth = dilationsAttr ? getI64FromArrayAttr(*dilationsAttr, 1) : 1;
|
||||
const int64_t strideHeight = getOptionalI64Attr(stridesAttr, 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64Attr(stridesAttr, 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64Attr(dilationsAttr, 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64Attr(dilationsAttr, 1, 1);
|
||||
|
||||
int64_t padHeightBegin = 0;
|
||||
int64_t padHeightEnd = 0;
|
||||
@@ -608,10 +657,10 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
int64_t padWidthEnd = 0;
|
||||
|
||||
if (padsAttr) {
|
||||
padHeightBegin = getI64FromArrayAttr(*padsAttr, 0);
|
||||
padWidthBegin = getI64FromArrayAttr(*padsAttr, 1);
|
||||
padHeightEnd = getI64FromArrayAttr(*padsAttr, 2);
|
||||
padWidthEnd = getI64FromArrayAttr(*padsAttr, 3);
|
||||
padHeightBegin = getI64Attr(*padsAttr, 0);
|
||||
padWidthBegin = getI64Attr(*padsAttr, 1);
|
||||
padHeightEnd = getI64Attr(*padsAttr, 2);
|
||||
padWidthEnd = getI64Attr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
// Compute padding from auto_pad attribute
|
||||
@@ -717,7 +766,7 @@ LogicalResult ConvToGemm::matchAndRewrite(ONNXConvOp convOp,
|
||||
}
|
||||
|
||||
Value result;
|
||||
if (llvm::all_of(groupResults, isHostFoldableValue)) {
|
||||
if (llvm::all_of(groupResults, isCompileTimeComputable)) {
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/1, groupResults);
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -83,7 +83,7 @@ static FailureOr<Value> materializeBroadcastedConstantTensor(Value value,
|
||||
}
|
||||
|
||||
auto broadcastedAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, broadcastedAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), broadcastedAttr, resultType);
|
||||
}
|
||||
|
||||
static FailureOr<Value>
|
||||
@@ -121,7 +121,7 @@ static FailureOr<Value> materializeReciprocalTensor(Value value,
|
||||
}
|
||||
|
||||
auto reciprocalAttr = DenseFPElementsAttr::get(resultType, reciprocalValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, resultType, reciprocalAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), reciprocalAttr, resultType);
|
||||
}
|
||||
|
||||
template <typename OnnxOp, typename SpatialOp>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
@@ -5,12 +7,10 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -19,14 +19,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static int64_t getStaticShapeElementCount(ArrayRef<int64_t> shape) {
|
||||
return std::accumulate(shape.begin(), shape.end(), int64_t {1}, std::multiplies<int64_t> {});
|
||||
}
|
||||
|
||||
static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t> lhsBatchShape,
|
||||
ArrayRef<int64_t> rhsBatchShape) {
|
||||
if (lhsBatchShape.empty())
|
||||
@@ -38,62 +30,60 @@ static FailureOr<SmallVector<int64_t>> inferSupportedBatchShape(ArrayRef<int64_t
|
||||
return SmallVector<int64_t>(lhsBatchShape.begin(), lhsBatchShape.end());
|
||||
}
|
||||
|
||||
static Value collapseBatchDims(Value value,
|
||||
int64_t batchSize,
|
||||
int64_t rows,
|
||||
int64_t cols,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
collapseBatchDims(Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 2 || type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto collapsedType =
|
||||
RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
auto collapsedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 2)},
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}
|
||||
};
|
||||
ReassociationIndices {static_cast<int64_t>(type.getRank() - 1)}};
|
||||
for (int64_t dim = 0; dim < type.getRank() - 2; ++dim)
|
||||
reassociation.front().push_back(dim);
|
||||
|
||||
auto buildCollapsed = [&](Value input) -> Value {
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input, reassociation);
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildCollapsed(value);
|
||||
|
||||
auto collapseCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {collapsedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildCollapsed(input));
|
||||
});
|
||||
return collapseCompute.getResult(0);
|
||||
return materializeOrComputeUnary(value, collapsedType, rewriter, loc, buildCollapsed);
|
||||
}
|
||||
|
||||
static Value expandBatchDims(Value value,
|
||||
RankedTensorType outputType,
|
||||
size_t batchRank,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
static Value
|
||||
expandBatchDims(Value value, RankedTensorType outputType, size_t batchRank, PatternRewriter& rewriter, Location loc) {
|
||||
if (cast<RankedTensorType>(value.getType()) == outputType)
|
||||
return value;
|
||||
|
||||
SmallVector<ReassociationIndices> reassociation = {
|
||||
ReassociationIndices {},
|
||||
SmallVector<ReassociationIndices> reassociation = {ReassociationIndices {},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank)},
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}
|
||||
};
|
||||
ReassociationIndices {static_cast<int64_t>(batchRank + 1)}};
|
||||
for (size_t dim = 0; dim < batchRank; ++dim)
|
||||
reassociation.front().push_back(static_cast<int64_t>(dim));
|
||||
|
||||
auto expandCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
auto buildExpanded = [&](Value input) -> Value {
|
||||
return tensor::ExpandShapeOp::create(rewriter, loc, outputType, input, reassociation).getResult();
|
||||
};
|
||||
return materializeOrComputeUnary(value, outputType, rewriter, loc, buildExpanded);
|
||||
}
|
||||
|
||||
static Value ensureBatchedTensor(
|
||||
Value value, int64_t batchSize, int64_t rows, int64_t cols, PatternRewriter& rewriter, Location loc) {
|
||||
auto type = cast<RankedTensorType>(value.getType());
|
||||
if (type.getRank() == 3)
|
||||
return value;
|
||||
|
||||
auto batchedType = RankedTensorType::get({batchSize, rows, cols}, type.getElementType(), type.getEncoding());
|
||||
auto buildExpanded = [&](Value input) -> Value {
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedType,
|
||||
input,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
return expandCompute.getResult(0);
|
||||
};
|
||||
return materializeOrComputeUnary(value, batchedType, rewriter, loc, buildExpanded);
|
||||
}
|
||||
|
||||
static Value extractBatchMatrix(Value value,
|
||||
@@ -112,7 +102,7 @@ static Value extractBatchMatrix(Value value,
|
||||
rewriter.getIndexAttr(batchSize == 1 ? 0 : batchIndex), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes = {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(rows), rewriter.getIndexAttr(cols)};
|
||||
SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, 3);
|
||||
auto matrixType = RankedTensorType::get({rows, cols}, type.getElementType());
|
||||
auto buildMatrix = [&](Value input) -> Value {
|
||||
Value slice = tensor::ExtractSliceOp::create(rewriter, loc, sliceType, input, offsets, sizes, strides);
|
||||
@@ -126,14 +116,7 @@ static Value extractBatchMatrix(Value value,
|
||||
});
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildMatrix(value);
|
||||
|
||||
auto batchMatrixCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {matrixType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildMatrix(input));
|
||||
});
|
||||
return batchMatrixCompute.getResult(0);
|
||||
return materializeOrComputeUnary(value, matrixType, rewriter, loc, buildMatrix);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -150,18 +133,7 @@ static Value transposeLastTwoDims(Value value, PatternRewriter& rewriter, Locati
|
||||
perm = {0, 2, 1};
|
||||
}
|
||||
|
||||
auto buildTranspose = [&](Value input) -> Value {
|
||||
return ONNXTransposeOp::create(rewriter, loc, transposedType, input, rewriter.getI64ArrayAttr(perm));
|
||||
};
|
||||
|
||||
if (isHostFoldableValue(value))
|
||||
return buildTranspose(value);
|
||||
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {transposedType}, {}, ValueRange {value}, [&](Value input) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, buildTranspose(input));
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
return transposeMaybeInCompute(value, transposedType, perm, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewriter, Location loc) {
|
||||
@@ -185,28 +157,527 @@ static Value transposeLastTwoDimsInCompute(Value value, PatternRewriter& rewrite
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, PatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
static Value createZeroPaddedTensor(Value value, RankedTensorType resultType, PatternRewriter& rewriter, Location loc) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
SmallVector<OpFoldResult> lowPads(sourceType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> highPads;
|
||||
highPads.reserve(sourceType.getRank());
|
||||
for (auto [sourceDim, resultDim] : llvm::zip(sourceType.getShape(), resultType.getShape()))
|
||||
highPads.push_back(rewriter.getIndexAttr(resultDim - sourceDim));
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
auto padOp = tensor::PadOp::create(rewriter, loc, resultType, value, lowPads, highPads);
|
||||
auto* padBlock = new Block();
|
||||
for (int64_t i = 0; i < sourceType.getRank(); ++i)
|
||||
padBlock->addArgument(rewriter.getIndexType(), loc);
|
||||
padOp.getRegion().push_back(padBlock);
|
||||
rewriter.setInsertionPointToStart(padBlock);
|
||||
auto zero = getOrCreateConstant(
|
||||
rewriter, padOp.getOperation(), rewriter.getZeroAttr(sourceType.getElementType()), sourceType.getElementType());
|
||||
tensor::YieldOp::create(rewriter, loc, zero);
|
||||
rewriter.setInsertionPointAfter(padOp);
|
||||
return padOp.getResult();
|
||||
}
|
||||
|
||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
static Value createPaddedBatchedInputCompute(Value input,
|
||||
RankedTensorType paddedInputType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
if (inputType == paddedInputType)
|
||||
return input;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
auto computeOp = createSpatCompute<1>(rewriter, loc, TypeRange {paddedInputType}, {}, input, [&](Value computeInput) {
|
||||
Value paddedInput = createZeroPaddedTensor(computeInput, paddedInputType, rewriter, loc);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, paddedInput);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializePaddedBatchedWeight(
|
||||
Value value, int64_t sourceBatch, int64_t targetBatch, RankedTensorType resultType, PatternRewriter& rewriter) {
|
||||
auto sourceType = cast<RankedTensorType>(value.getType());
|
||||
if (sourceType == resultType)
|
||||
return value;
|
||||
|
||||
auto denseAttr = getHostConstDenseElementsAttr(value);
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
const int64_t sourceRows = sourceType.getRank() == 2 ? sourceType.getDimSize(0) : sourceType.getDimSize(1);
|
||||
const int64_t sourceCols = sourceType.getRank() == 2 ? sourceType.getDimSize(1) : sourceType.getDimSize(2);
|
||||
const int64_t targetRows = resultType.getDimSize(1);
|
||||
const int64_t targetCols = resultType.getDimSize(2);
|
||||
SmallVector<Attribute> sourceValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(resultType.getNumElements(), rewriter.getZeroAttr(resultType.getElementType()));
|
||||
|
||||
for (int64_t batchIdx = 0; batchIdx < targetBatch; ++batchIdx) {
|
||||
const int64_t sourceBatchIdx = sourceType.getRank() == 2 ? 0 : (sourceBatch == 1 ? 0 : batchIdx);
|
||||
const int64_t sourceBatchBase = sourceType.getRank() == 2 ? 0 : sourceBatchIdx * sourceRows * sourceCols;
|
||||
const int64_t targetBatchBase = batchIdx * targetRows * targetCols;
|
||||
for (int64_t row = 0; row < sourceRows; ++row)
|
||||
for (int64_t col = 0; col < sourceCols; ++col)
|
||||
resultValues[targetBatchBase + row * targetCols + col] = sourceValues[sourceBatchBase + row * sourceCols + col];
|
||||
}
|
||||
|
||||
auto resultAttr = DenseElementsAttr::get(resultType, resultValues);
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), resultAttr, resultType);
|
||||
}
|
||||
|
||||
static Value extractBatchedATile(Value a,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value row,
|
||||
Value kOffset,
|
||||
RankedTensorType aTileType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto aSliceType = RankedTensorType::get({1, 1, aTileType.getDimSize(1)}, aTileType.getElementType());
|
||||
SmallVector<OpFoldResult> offsets {
|
||||
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), row, kOffset};
|
||||
SmallVector<OpFoldResult> sizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(aTileType.getDimSize(1))};
|
||||
auto slice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aSliceType, a, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
aTileType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static Value extractBatchedBTile(Value b,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value kOffset,
|
||||
Value hOffset,
|
||||
RankedTensorType bTileType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto bSliceType =
|
||||
RankedTensorType::get({1, bTileType.getDimSize(0), bTileType.getDimSize(1)}, bTileType.getElementType());
|
||||
SmallVector<OpFoldResult> offsets {
|
||||
sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0)) : OpFoldResult(batch), kOffset, hOffset};
|
||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(1),
|
||||
rewriter.getIndexAttr(bTileType.getDimSize(0)),
|
||||
rewriter.getIndexAttr(bTileType.getDimSize(1))};
|
||||
auto slice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, bSliceType, b, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
bTileType,
|
||||
slice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static Value getBatchLaneIndex(
|
||||
Value lane, int64_t numOutRows, int64_t numKSlices, int64_t numOutHSlices, PatternRewriter& rewriter, Location loc) {
|
||||
return affineFloorDivConst(
|
||||
rewriter, loc, lane, numOutRows * numKSlices * numOutHSlices, rewriter.getInsertionBlock()->getParentOp());
|
||||
}
|
||||
|
||||
static spatial::SpatComputeBatch createBatchedVmmBatch(Value a,
|
||||
Value b,
|
||||
RankedTensorType aType,
|
||||
int64_t aBatchCount,
|
||||
RankedTensorType bType,
|
||||
int64_t bBatchCount,
|
||||
RankedTensorType partialPiecesType,
|
||||
int64_t numOutRows,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = partialPiecesType.getDimSize(0);
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter,
|
||||
loc,
|
||||
TypeRange {partialPiecesType},
|
||||
laneCount,
|
||||
ValueRange {b},
|
||||
ValueRange {a},
|
||||
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value row = affineModConst(rewriter, loc, args.lane, numOutRows, anchorOp);
|
||||
Value outerLane = affineFloorDivConst(rewriter, loc, args.lane, numOutRows, anchorOp);
|
||||
Value batch = getBatchLaneIndex(args.lane, numOutRows, numKSlices, numOutHSlices, rewriter, loc);
|
||||
Value sliceLane = affineModConst(rewriter, loc, outerLane, numKSlices * numOutHSlices, anchorOp);
|
||||
Value kSlice = affineModConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
|
||||
Value hSlice = affineFloorDivConst(rewriter, loc, sliceLane, numKSlices, anchorOp);
|
||||
Value kOffset = affineMulConst(rewriter, loc, kSlice, crossbarSize.getValue(), anchorOp);
|
||||
Value hOffset = affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), anchorOp);
|
||||
|
||||
auto aTileType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, aType.getElementType());
|
||||
auto bTileType = RankedTensorType::get(
|
||||
{static_cast<int64_t>(crossbarSize.getValue()), static_cast<int64_t>(crossbarSize.getValue())},
|
||||
bType.getElementType());
|
||||
auto pieceType =
|
||||
RankedTensorType::get({1, static_cast<int64_t>(crossbarSize.getValue())}, partialPiecesType.getElementType());
|
||||
|
||||
Value aTile =
|
||||
extractBatchedATile(args.inputs.front(), aBatchCount, batch, row, kOffset, aTileType, rewriter, loc);
|
||||
Value bTile =
|
||||
extractBatchedBTile(args.weights.front(), bBatchCount, batch, kOffset, hOffset, bTileType, rewriter, loc);
|
||||
Value piece = spatial::SpatVMMOp::create(rewriter, loc, pieceType, bTile, aTile).getResult();
|
||||
|
||||
SmallVector<OpFoldResult> pieceOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> pieceSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, piece, args.outputs.front(), pieceOffsets, pieceSizes, getUnitStrides(rewriter, 2));
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected batched MatMul VMM construction to succeed");
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static Value extractDynamicBatchedBColumn(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value column,
|
||||
RankedTensorType vectorType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto columnSliceType = RankedTensorType::get({1, vectorType.getDimSize(1), 1}, vectorType.getElementType());
|
||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||
: OpFoldResult(batch),
|
||||
rewriter.getIndexAttr(0),
|
||||
column};
|
||||
SmallVector<OpFoldResult> sizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1)), rewriter.getIndexAttr(1)};
|
||||
SmallVector<OpFoldResult> strides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value columnSlice = tensor::ExtractSliceOp::create(rewriter, loc, columnSliceType, matrix, offsets, sizes, strides);
|
||||
auto collapsedType = RankedTensorType::get({vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||
Value collapsed = tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
collapsedType,
|
||||
columnSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1, 2}
|
||||
})
|
||||
.getResult();
|
||||
return tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
vectorType,
|
||||
collapsed,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1}
|
||||
})
|
||||
.getResult();
|
||||
}
|
||||
|
||||
static Value extractDynamicBatchedBRow(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value row,
|
||||
RankedTensorType vectorType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||
: OpFoldResult(batch),
|
||||
row,
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||
auto rowSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
vectorType,
|
||||
rowSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static Value extractDynamicBatchedRowVector(Value matrix,
|
||||
int64_t sourceBatchCount,
|
||||
Value batch,
|
||||
Value row,
|
||||
RankedTensorType vectorType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto rowSliceType = RankedTensorType::get({1, 1, vectorType.getDimSize(1)}, vectorType.getElementType());
|
||||
SmallVector<OpFoldResult> offsets {sourceBatchCount == 1 ? OpFoldResult(rewriter.getIndexAttr(0))
|
||||
: OpFoldResult(batch),
|
||||
row,
|
||||
rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(vectorType.getDimSize(1))};
|
||||
auto rowSlice =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, rowSliceType, matrix, offsets, sizes, getUnitStrides(rewriter, 3));
|
||||
return tensor::CollapseShapeOp::create(rewriter,
|
||||
loc,
|
||||
vectorType,
|
||||
rowSlice,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
}
|
||||
|
||||
static spatial::SpatComputeBatch createBatchedVvdmulBatch(Value a,
|
||||
int64_t aBatchCount,
|
||||
Value b,
|
||||
int64_t bBatchCount,
|
||||
RankedTensorType aType,
|
||||
RankedTensorType bType,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
bool bAlreadyTransposed,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t numBatches = outType.getDimSize(0);
|
||||
const int64_t numOutRows = outType.getDimSize(1);
|
||||
const int64_t numOutCols = outType.getDimSize(2);
|
||||
const int64_t reductionSize = aType.getDimSize(2);
|
||||
const int64_t laneCount = numBatches * numOutRows * numOutCols;
|
||||
auto batchOp = createSpatComputeBatch(
|
||||
rewriter,
|
||||
loc,
|
||||
TypeRange {scalarPiecesType},
|
||||
laneCount,
|
||||
ValueRange {},
|
||||
ValueRange {a, b},
|
||||
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value batch = affineFloorDivConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
|
||||
Value batchLane = affineModConst(rewriter, loc, args.lane, numOutRows * numOutCols, anchorOp);
|
||||
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
|
||||
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
|
||||
|
||||
auto vectorType = RankedTensorType::get({1, reductionSize}, aType.getElementType());
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
Value aVector =
|
||||
extractDynamicBatchedRowVector(args.inputs[0], aBatchCount, batch, row, vectorType, rewriter, loc);
|
||||
Value bVector =
|
||||
bAlreadyTransposed
|
||||
? extractDynamicBatchedBRow(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc)
|
||||
: extractDynamicBatchedBColumn(args.inputs[1], bBatchCount, batch, column, vectorType, rewriter, loc);
|
||||
Value scalar = spatial::SpatVVDMulOp::create(rewriter, loc, scalarType, aVector, bVector).getResult();
|
||||
SmallVector<OpFoldResult> outputOffsets {args.lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, scalar, args.outputs.front(), outputOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||
});
|
||||
assert(succeeded(batchOp) && "expected batched MatMul VVDMul construction to succeed");
|
||||
return *batchOp;
|
||||
}
|
||||
|
||||
static Value createBatchedDynamicOutputCompute(Value scalarPieces,
|
||||
RankedTensorType scalarPiecesType,
|
||||
RankedTensorType outType,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
const int64_t laneCount = scalarPiecesType.getDimSize(0);
|
||||
const int64_t numOutRows = outType.getDimSize(1);
|
||||
const int64_t numOutCols = outType.getDimSize(2);
|
||||
auto scalarType = RankedTensorType::get({1, 1}, outType.getElementType());
|
||||
auto outputScalarType = RankedTensorType::get({1, 1, 1}, outType.getElementType());
|
||||
|
||||
auto computeOp =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, ValueRange {scalarPieces}, [&](Value pieces) {
|
||||
Value outputInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType()).getResult();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cLaneCount = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), laneCount);
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cLaneCount, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
|
||||
Value lane = loop.getInductionVar();
|
||||
Value outputAcc = loop.getRegionIterArgs().front();
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value batch = affineFloorDivConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
|
||||
Value batchLane = affineModConst(rewriter, loc, lane, numOutRows * numOutCols, anchorOp);
|
||||
Value row = affineFloorDivConst(rewriter, loc, batchLane, numOutCols, anchorOp);
|
||||
Value column = affineModConst(rewriter, loc, batchLane, numOutCols, anchorOp);
|
||||
SmallVector<OpFoldResult> scalarOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> scalarSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scalar = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarType, pieces, scalarOffsets, scalarSizes, getUnitStrides(rewriter, 2));
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
outputScalarType,
|
||||
scalar,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0},
|
||||
{1, 2}
|
||||
});
|
||||
SmallVector<OpFoldResult> outputOffsets {batch, row, column};
|
||||
SmallVector<OpFoldResult> outputSizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
scf::YieldOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
tensor::InsertSliceOp::create(
|
||||
rewriter, loc, expanded, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||
.getResult());
|
||||
|
||||
rewriter.setInsertionPointAfter(loop);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, loop.getResult(0));
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value transposeBatchedOutput(Value value, RankedTensorType outputType, PatternRewriter& rewriter, Location loc) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outputType}, {}, ValueRange {value}, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outputType, input, rewriter.getI64ArrayAttr({0, 2, 1}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
return transposeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value extractBatchedReductionPiece(Value partialPiecesArg,
|
||||
Value batch,
|
||||
Value hSlice,
|
||||
int64_t kSlice,
|
||||
RankedTensorType pieceType,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
int64_t numOutRows,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value batchOffset = affineMulConst(rewriter, loc, batch, numOutRows * numKSlices * numOutHSlices, anchorOp);
|
||||
Value hOffset = affineMulConst(rewriter, loc, hSlice, numKSlices * numOutRows, anchorOp);
|
||||
Value kOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), kSlice * numOutRows);
|
||||
Value batchAndHSlice = arith::AddIOp::create(rewriter, loc, batchOffset, hOffset);
|
||||
Value pieceOffset = arith::AddIOp::create(rewriter, loc, batchAndHSlice, kOffset);
|
||||
SmallVector<OpFoldResult> offsets {pieceOffset, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> sizes {rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
return tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, pieceType, partialPiecesArg, offsets, sizes, getUnitStrides(rewriter, 2));
|
||||
}
|
||||
|
||||
static Value reduceBatchedPartialPiecesForHSlice(Value partialPiecesArg,
|
||||
Value batch,
|
||||
Value hSlice,
|
||||
RankedTensorType pieceType,
|
||||
int64_t numKSlices,
|
||||
int64_t numOutHSlices,
|
||||
int64_t numOutRows,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
SmallVector<Value> activePieces;
|
||||
activePieces.reserve(numKSlices);
|
||||
for (int64_t kSlice = 0; kSlice < numKSlices; ++kSlice)
|
||||
activePieces.push_back(extractBatchedReductionPiece(
|
||||
partialPiecesArg, batch, hSlice, kSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc));
|
||||
|
||||
while (activePieces.size() > 1) {
|
||||
SmallVector<Value> nextPieces;
|
||||
nextPieces.reserve((activePieces.size() + 1) / 2);
|
||||
for (size_t pieceIndex = 0; pieceIndex + 1 < activePieces.size(); pieceIndex += 2)
|
||||
nextPieces.push_back(
|
||||
spatial::SpatVAddOp::create(rewriter, loc, pieceType, activePieces[pieceIndex], activePieces[pieceIndex + 1])
|
||||
.getResult());
|
||||
if (activePieces.size() % 2 != 0)
|
||||
nextPieces.push_back(activePieces.back());
|
||||
activePieces = std::move(nextPieces);
|
||||
}
|
||||
|
||||
return activePieces.front();
|
||||
}
|
||||
|
||||
static Value createBatchedReductionCompute(Value partialPieces,
|
||||
RankedTensorType partialPiecesType,
|
||||
RankedTensorType outType,
|
||||
RankedTensorType paddedOutType,
|
||||
int64_t numBatches,
|
||||
int64_t numKSlices,
|
||||
PatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, loc, TypeRange {outType}, {}, ValueRange {partialPieces}, [&](Value partialPiecesArg) {
|
||||
const int64_t numOutRows = outType.getDimSize(1);
|
||||
const int64_t numOutHSlices = ceilIntegerDivide(outType.getDimSize(2), crossbarSize.getValue());
|
||||
auto pieceType = RankedTensorType::get({numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
partialPiecesType.getElementType());
|
||||
auto outputSliceType = RankedTensorType::get({1, numOutRows, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
partialPiecesType.getElementType());
|
||||
|
||||
Value outputInit =
|
||||
tensor::EmptyOp::create(rewriter, loc, paddedOutType.getShape(), paddedOutType.getElementType()).getResult();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 1);
|
||||
Value cNumBatches = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numBatches);
|
||||
Value cNumOutHSlices =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), numOutHSlices);
|
||||
|
||||
auto batchLoop = scf::ForOp::create(rewriter, loc, c0, cNumBatches, c1, ValueRange {outputInit});
|
||||
rewriter.setInsertionPointToStart(batchLoop.getBody());
|
||||
Value batch = batchLoop.getInductionVar();
|
||||
Value batchAcc = batchLoop.getRegionIterArgs().front();
|
||||
|
||||
auto hLoop = scf::ForOp::create(rewriter, loc, c0, cNumOutHSlices, c1, ValueRange {batchAcc});
|
||||
rewriter.setInsertionPointToStart(hLoop.getBody());
|
||||
Value hSlice = hLoop.getInductionVar();
|
||||
Value outputAcc = hLoop.getRegionIterArgs().front();
|
||||
|
||||
Value reduced = reduceBatchedPartialPiecesForHSlice(
|
||||
partialPiecesArg, batch, hSlice, pieceType, numKSlices, numOutHSlices, numOutRows, rewriter, loc);
|
||||
Value expandedReduced = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
outputSliceType,
|
||||
reduced,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
Value hOffset =
|
||||
affineMulConst(rewriter, loc, hSlice, crossbarSize.getValue(), rewriter.getInsertionBlock()->getParentOp());
|
||||
SmallVector<OpFoldResult> outputOffsets {batch, rewriter.getIndexAttr(0), hOffset};
|
||||
SmallVector<OpFoldResult> outputSizes {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(numOutRows), rewriter.getIndexAttr(crossbarSize.getValue())};
|
||||
scf::YieldOp::create(
|
||||
rewriter,
|
||||
loc,
|
||||
tensor::InsertSliceOp::create(
|
||||
rewriter, loc, expandedReduced, outputAcc, outputOffsets, outputSizes, getUnitStrides(rewriter, 3))
|
||||
.getResult());
|
||||
|
||||
rewriter.setInsertionPointAfter(hLoop);
|
||||
scf::YieldOp::create(rewriter, loc, hLoop.getResult(0));
|
||||
|
||||
rewriter.setInsertionPointAfter(batchLoop);
|
||||
Value paddedOutput = batchLoop.getResult(0);
|
||||
Value result = paddedOutput;
|
||||
if (paddedOutType != outType) {
|
||||
SmallVector<OpFoldResult> outputOffsets {
|
||||
rewriter.getIndexAttr(0), rewriter.getIndexAttr(0), rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(numBatches),
|
||||
rewriter.getIndexAttr(outType.getDimSize(1)),
|
||||
rewriter.getIndexAttr(outType.getDimSize(2))};
|
||||
result = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, outType, paddedOutput, outputOffsets, outputSizes, getUnitStrides(rewriter, 3));
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, loc, result);
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
struct MatMulShapeInfo {
|
||||
RankedTensorType lhsType;
|
||||
RankedTensorType rhsType;
|
||||
RankedTensorType outType;
|
||||
SmallVector<int64_t> batchShape;
|
||||
int64_t lhsBatch;
|
||||
int64_t rhsBatch;
|
||||
int64_t batch;
|
||||
int64_t m;
|
||||
int64_t k;
|
||||
int64_t n;
|
||||
};
|
||||
|
||||
static FailureOr<MatMulShapeInfo> analyzeMatMulShape(ONNXMatMulOp matmulOp) {
|
||||
auto lhsType = dyn_cast<RankedTensorType>(matmulOp.getA().getType());
|
||||
auto rhsType = dyn_cast<RankedTensorType>(matmulOp.getB().getType());
|
||||
auto outType = dyn_cast<RankedTensorType>(matmulOp.getY().getType());
|
||||
@@ -215,8 +686,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
if (lhsType.getRank() < 2 || rhsType.getRank() < 2 || outType.getRank() < 2)
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(lhsType.getShape()) || !haveStaticPositiveShape(rhsType.getShape())
|
||||
|| !haveStaticPositiveShape(outType.getShape()))
|
||||
if (!hasStaticPositiveShape(lhsType) || !hasStaticPositiveShape(rhsType) || !hasStaticPositiveShape(outType))
|
||||
return failure();
|
||||
|
||||
SmallVector<int64_t> lhsBatchShape(lhsType.getShape().begin(), lhsType.getShape().end() - 2);
|
||||
@@ -224,10 +694,10 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
auto batchShape = inferSupportedBatchShape(lhsBatchShape, rhsBatchShape);
|
||||
if (failed(batchShape))
|
||||
return failure();
|
||||
|
||||
const int64_t lhsBatch = lhsBatchShape.empty() ? 1 : getStaticShapeElementCount(lhsBatchShape);
|
||||
const int64_t rhsBatch = rhsBatchShape.empty() ? 1 : getStaticShapeElementCount(rhsBatchShape);
|
||||
const int64_t batch = batchShape->empty() ? 1 : getStaticShapeElementCount(*batchShape);
|
||||
|
||||
const int64_t m = lhsType.getDimSize(lhsType.getRank() - 2);
|
||||
const int64_t k = lhsType.getDimSize(lhsType.getRank() - 1);
|
||||
const int64_t rhsK = rhsType.getDimSize(rhsType.getRank() - 2);
|
||||
@@ -246,30 +716,38 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
return failure();
|
||||
}
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isHostFoldableValue(matmulOp.getA()) && !isHostFoldableValue(matmulOp.getB());
|
||||
return MatMulShapeInfo {lhsType, rhsType, outType, *batchShape, lhsBatch, rhsBatch, batch, m, k, n};
|
||||
}
|
||||
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), lhsBatch, m, k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), rhsBatch, k, n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = lhsBatch;
|
||||
int64_t rhsBatchForGemm = rhsBatch;
|
||||
int64_t gemmM = m;
|
||||
int64_t gemmK = k;
|
||||
int64_t gemmN = n;
|
||||
struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||
if (failed(shapeInfo) || shapeInfo->outType.getRank() != 2)
|
||||
return failure();
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
||||
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
int64_t gemmM = shapeInfo->m;
|
||||
int64_t gemmK = shapeInfo->k;
|
||||
int64_t gemmN = shapeInfo->n;
|
||||
if (useTransposedForm) {
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = rhsBatch;
|
||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = lhsBatch;
|
||||
gemmM = n;
|
||||
gemmN = m;
|
||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
gemmM = shapeInfo->n;
|
||||
gemmN = shapeInfo->m;
|
||||
}
|
||||
|
||||
auto gemmType = RankedTensorType::get({gemmM, gemmN}, outType.getElementType());
|
||||
auto batchedOutType = RankedTensorType::get({1, m, n}, outType.getElementType());
|
||||
auto gemmType = RankedTensorType::get({gemmM, gemmN}, shapeInfo->outType.getElementType());
|
||||
Value none = ONNXNoneOp::create(rewriter, loc, rewriter.getNoneType());
|
||||
|
||||
if (outType.getRank() == 2) {
|
||||
Value lhsMatrix = extractBatchMatrix(lhs, /*batchIndex=*/0, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||
Value rhsMatrix = extractBatchMatrix(rhs, /*batchIndex=*/0, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||
@@ -285,8 +763,9 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
.getY();
|
||||
if (useTransposedForm) {
|
||||
auto transposeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed = ONNXTransposeOp::create(rewriter, loc, outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {shapeInfo->outType}, {}, gemmResult, [&](Value input) {
|
||||
Value transposed =
|
||||
ONNXTransposeOp::create(rewriter, loc, shapeInfo->outType, input, rewriter.getI64ArrayAttr({1, 0}));
|
||||
spatial::SpatYieldOp::create(rewriter, loc, transposed);
|
||||
});
|
||||
gemmResult = transposeCompute.getResult(0);
|
||||
@@ -294,48 +773,115 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
rewriter.replaceOp(matmulOp, gemmResult);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
SmallVector<Value> batchResults;
|
||||
batchResults.reserve(batch);
|
||||
for (int64_t batchIdx = 0; batchIdx < batch; batchIdx++) {
|
||||
Value lhsMatrix = extractBatchMatrix(lhs, batchIdx, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||
Value rhsMatrix = extractBatchMatrix(rhs, batchIdx, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||
Value gemmResult = ONNXGemmOp::create(rewriter,
|
||||
loc,
|
||||
gemmType,
|
||||
lhsMatrix,
|
||||
rhsMatrix,
|
||||
none,
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getF32FloatAttr(1.0f),
|
||||
rewriter.getBoolAttr(false),
|
||||
rewriter.getBoolAttr(false))
|
||||
.getY();
|
||||
auto batchResultCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {batchedOutType}, {}, gemmResult, [&](Value input) {
|
||||
Value resultMatrix = input;
|
||||
struct MatMulBatchedToSpatialComputes : OpRewritePattern<ONNXMatMulOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXMatMulOp matmulOp, PatternRewriter& rewriter) const override {
|
||||
auto shapeInfo = analyzeMatMulShape(matmulOp);
|
||||
if (failed(shapeInfo))
|
||||
return failure();
|
||||
if (shapeInfo->outType.getRank() == 2)
|
||||
return failure();
|
||||
|
||||
Location loc = matmulOp.getLoc();
|
||||
bool useTransposedForm = isCompileTimeComputable(matmulOp.getA()) && !isCompileTimeComputable(matmulOp.getB());
|
||||
|
||||
Value lhs = collapseBatchDims(matmulOp.getA(), shapeInfo->lhsBatch, shapeInfo->m, shapeInfo->k, rewriter, loc);
|
||||
Value rhs = collapseBatchDims(matmulOp.getB(), shapeInfo->rhsBatch, shapeInfo->k, shapeInfo->n, rewriter, loc);
|
||||
int64_t lhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
int64_t rhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
int64_t gemmM = shapeInfo->m;
|
||||
int64_t gemmK = shapeInfo->k;
|
||||
int64_t gemmN = shapeInfo->n;
|
||||
if (useTransposedForm) {
|
||||
resultMatrix = ONNXTransposeOp::create(rewriter,
|
||||
loc,
|
||||
RankedTensorType::get({m, n}, outType.getElementType()),
|
||||
input,
|
||||
rewriter.getI64ArrayAttr({1, 0}));
|
||||
}
|
||||
Value expanded = tensor::ExpandShapeOp::create(rewriter,
|
||||
loc,
|
||||
batchedOutType,
|
||||
resultMatrix,
|
||||
SmallVector<ReassociationIndices> {
|
||||
{0, 1},
|
||||
{2}
|
||||
});
|
||||
spatial::SpatYieldOp::create(rewriter, loc, expanded);
|
||||
});
|
||||
batchResults.push_back(batchResultCompute.getResult(0));
|
||||
lhs = transposeLastTwoDimsInCompute(matmulOp.getB(), rewriter, loc);
|
||||
lhsBatchForGemm = shapeInfo->rhsBatch;
|
||||
rhs = transposeLastTwoDims(matmulOp.getA(), rewriter, loc);
|
||||
rhsBatchForGemm = shapeInfo->lhsBatch;
|
||||
gemmM = shapeInfo->n;
|
||||
gemmN = shapeInfo->m;
|
||||
}
|
||||
|
||||
Value result = concatValues(batchResults, /*axis=*/0, rewriter, loc);
|
||||
result = expandBatchDims(result, outType, batchShape->size(), rewriter, loc);
|
||||
lhs = ensureBatchedTensor(lhs, lhsBatchForGemm, gemmM, gemmK, rewriter, loc);
|
||||
rhs = ensureBatchedTensor(rhs, rhsBatchForGemm, gemmK, gemmN, rewriter, loc);
|
||||
auto lhsBatchedType = cast<RankedTensorType>(lhs.getType());
|
||||
auto rhsBatchedType = cast<RankedTensorType>(rhs.getType());
|
||||
auto directOutType = RankedTensorType::get({shapeInfo->batch, gemmM, gemmN}, shapeInfo->outType.getElementType());
|
||||
|
||||
if (isCompileTimeComputable(rhs)) {
|
||||
const int64_t numKSlices = ceilIntegerDivide(gemmK, crossbarSize.getValue());
|
||||
const int64_t numOutHSlices = ceilIntegerDivide(gemmN, crossbarSize.getValue());
|
||||
const int64_t paddedReductionSize = numKSlices * static_cast<int64_t>(crossbarSize.getValue());
|
||||
const int64_t paddedOutCols = numOutHSlices * static_cast<int64_t>(crossbarSize.getValue());
|
||||
auto paddedLhsType = RankedTensorType::get(
|
||||
{lhsBatchForGemm, gemmM, paddedReductionSize}, lhsBatchedType.getElementType(), lhsBatchedType.getEncoding());
|
||||
auto paddedRhsType = RankedTensorType::get({shapeInfo->batch, paddedReductionSize, paddedOutCols},
|
||||
rhsBatchedType.getElementType(),
|
||||
rhsBatchedType.getEncoding());
|
||||
auto paddedOutType =
|
||||
RankedTensorType::get({shapeInfo->batch, gemmM, paddedOutCols}, shapeInfo->outType.getElementType());
|
||||
|
||||
auto paddedRhs = materializePaddedBatchedWeight(rhs, rhsBatchForGemm, shapeInfo->batch, paddedRhsType, rewriter);
|
||||
if (succeeded(paddedRhs)) {
|
||||
Value paddedLhs = createPaddedBatchedInputCompute(lhs, paddedLhsType, rewriter, loc);
|
||||
const int64_t laneCount = shapeInfo->batch * gemmM * numKSlices * numOutHSlices;
|
||||
auto partialPiecesType = RankedTensorType::get({laneCount, static_cast<int64_t>(crossbarSize.getValue())},
|
||||
shapeInfo->outType.getElementType());
|
||||
auto batchOp = createBatchedVmmBatch(paddedLhs,
|
||||
*paddedRhs,
|
||||
paddedLhsType,
|
||||
lhsBatchForGemm,
|
||||
paddedRhsType,
|
||||
rhsBatchForGemm,
|
||||
partialPiecesType,
|
||||
gemmM,
|
||||
numKSlices,
|
||||
numOutHSlices,
|
||||
rewriter,
|
||||
loc);
|
||||
Value result = createBatchedReductionCompute(batchOp.getResult(0),
|
||||
partialPiecesType,
|
||||
directOutType,
|
||||
paddedOutType,
|
||||
shapeInfo->batch,
|
||||
numKSlices,
|
||||
rewriter,
|
||||
loc);
|
||||
if (useTransposedForm)
|
||||
result = transposeBatchedOutput(
|
||||
result,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
const int64_t laneCount = shapeInfo->batch * gemmM * gemmN;
|
||||
auto scalarPiecesType = RankedTensorType::get({laneCount, 1}, shapeInfo->outType.getElementType());
|
||||
auto batchOp = createBatchedVvdmulBatch(lhs,
|
||||
lhsBatchForGemm,
|
||||
rhs,
|
||||
rhsBatchForGemm,
|
||||
lhsBatchedType,
|
||||
rhsBatchedType,
|
||||
scalarPiecesType,
|
||||
directOutType,
|
||||
false,
|
||||
rewriter,
|
||||
loc);
|
||||
Value result =
|
||||
createBatchedDynamicOutputCompute(batchOp.getResult(0), scalarPiecesType, directOutType, rewriter, loc);
|
||||
if (useTransposedForm)
|
||||
result = transposeBatchedOutput(
|
||||
result,
|
||||
RankedTensorType::get({shapeInfo->batch, shapeInfo->m, shapeInfo->n}, shapeInfo->outType.getElementType()),
|
||||
rewriter,
|
||||
loc);
|
||||
result = expandBatchDims(result, shapeInfo->outType, shapeInfo->batchShape.size(), rewriter, loc);
|
||||
rewriter.replaceOp(matmulOp, result);
|
||||
return success();
|
||||
}
|
||||
@@ -344,7 +890,7 @@ struct MatMulToGemm : OpRewritePattern<ONNXMatMulOp> {
|
||||
} // namespace
|
||||
|
||||
void populateMatMulRewritePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.insert<MatMulToGemm>(ctx);
|
||||
patterns.insert<MatMulToGemm, MatMulBatchedToSpatialComputes>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/AffineUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -16,26 +19,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static SmallVector<int64_t> normalizeAxes(ArrayAttr axesAttr, int64_t rank) {
|
||||
SmallVector<int64_t> normalizedAxes;
|
||||
if (!axesAttr) {
|
||||
normalizedAxes.reserve(rank);
|
||||
for (int64_t axis = 0; axis < rank; axis++)
|
||||
normalizedAxes.push_back(axis);
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
normalizedAxes.reserve(axesAttr.size());
|
||||
for (Attribute attr : axesAttr) {
|
||||
int64_t axis = cast<IntegerAttr>(attr).getInt();
|
||||
normalizedAxes.push_back(axis >= 0 ? axis : rank + axis);
|
||||
}
|
||||
|
||||
llvm::sort(normalizedAxes);
|
||||
normalizedAxes.erase(std::unique(normalizedAxes.begin(), normalizedAxes.end()), normalizedAxes.end());
|
||||
return normalizedAxes;
|
||||
}
|
||||
|
||||
static SmallVector<bool> buildReducedAxesMask(ArrayRef<int64_t> axes, int64_t rank) {
|
||||
SmallVector<bool> reducedAxes(rank, false);
|
||||
for (int64_t axis : axes) {
|
||||
@@ -50,6 +33,184 @@ static RankedTensorType getAllOnesType(RankedTensorType inputType, Type elementT
|
||||
return RankedTensorType::get(SmallVector<int64_t>(inputType.getRank(), 1), elementType);
|
||||
}
|
||||
|
||||
static RankedTensorType getKeepdimsType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> shape;
|
||||
shape.reserve(inputType.getRank());
|
||||
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||
shape.push_back(isReduced ? 1 : dim);
|
||||
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
|
||||
}
|
||||
|
||||
static RankedTensorType getCompactKeptType(RankedTensorType inputType, Type elementType, ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> shape;
|
||||
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||
if (!isReduced)
|
||||
shape.push_back(dim);
|
||||
return RankedTensorType::get(shape, elementType, inputType.getEncoding());
|
||||
}
|
||||
|
||||
static RankedTensorType getReducedSliceType(RankedTensorType inputType, ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> shape;
|
||||
shape.reserve(inputType.getRank());
|
||||
for (auto [dim, isReduced] : llvm::zip_equal(inputType.getShape(), reducedAxes))
|
||||
shape.push_back(isReduced ? dim : 1);
|
||||
return RankedTensorType::get(shape, inputType.getElementType(), inputType.getEncoding());
|
||||
}
|
||||
|
||||
static RankedTensorType getLanePackedKeepdimsType(int64_t laneCount, RankedTensorType leafType) {
|
||||
SmallVector<int64_t> shape(leafType.getShape().begin(), leafType.getShape().end());
|
||||
shape.front() = laneCount;
|
||||
return RankedTensorType::get(shape, leafType.getElementType(), leafType.getEncoding());
|
||||
}
|
||||
|
||||
static SmallVector<int64_t> getKeptAxes(ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<int64_t> keptAxes;
|
||||
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes))
|
||||
if (!isReduced)
|
||||
keptAxes.push_back(static_cast<int64_t>(axis));
|
||||
return keptAxes;
|
||||
}
|
||||
|
||||
static Value
|
||||
computeLaneIndex(Value lane, int64_t stride, int64_t dimSize, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
if (dimSize == 1)
|
||||
return getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
|
||||
MLIRContext* context = rewriter.getContext();
|
||||
AffineExpr d0 = getAffineDimExpr(0, context);
|
||||
AffineExpr expr = d0;
|
||||
if (stride != 1)
|
||||
expr = expr.floorDiv(stride);
|
||||
if (dimSize != 1)
|
||||
expr = expr % dimSize;
|
||||
return createOrFoldAffineApply(rewriter, loc, expr, ValueRange {lane}, rewriter.getInsertionBlock()->getParentOp());
|
||||
}
|
||||
|
||||
static FailureOr<Value> buildReduceMeanKeepdimsBatch(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
RankedTensorType batchType,
|
||||
RankedTensorType leafType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
auto sliceType = getReducedSliceType(inputType, reducedAxes);
|
||||
SmallVector<int64_t> keptAxes = getKeptAxes(reducedAxes);
|
||||
|
||||
int64_t laneCount = 1;
|
||||
SmallVector<int64_t> keptAxisStrides(keptAxes.size(), 1);
|
||||
for (int64_t index = static_cast<int64_t>(keptAxes.size()) - 1; index >= 0; --index) {
|
||||
keptAxisStrides[index] = laneCount;
|
||||
int64_t dimSize = inputType.getDimSize(keptAxes[index]);
|
||||
if (dimSize <= 0)
|
||||
return failure();
|
||||
if (laneCount > std::numeric_limits<int32_t>::max() / dimSize)
|
||||
return failure();
|
||||
laneCount *= dimSize;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult> sliceOffsets;
|
||||
SmallVector<OpFoldResult> sliceSizes;
|
||||
SmallVector<OpFoldResult> insertOffsets;
|
||||
SmallVector<OpFoldResult> insertSizes(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides = getUnitStrides(rewriter, inputType.getRank());
|
||||
sliceOffsets.reserve(inputType.getRank());
|
||||
sliceSizes.reserve(inputType.getRank());
|
||||
insertOffsets.reserve(inputType.getRank());
|
||||
|
||||
auto batchOp =
|
||||
createSpatComputeBatch(rewriter,
|
||||
loc,
|
||||
TypeRange {batchType},
|
||||
laneCount,
|
||||
{},
|
||||
ValueRange {input},
|
||||
[&](detail::SpatComputeBatchBodyArgs args) {
|
||||
size_t keptAxisIndex = 0;
|
||||
sliceOffsets.clear();
|
||||
sliceSizes.clear();
|
||||
insertOffsets.clear();
|
||||
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
|
||||
if (isReduced) {
|
||||
sliceOffsets.push_back(rewriter.getIndexAttr(0));
|
||||
sliceSizes.push_back(rewriter.getIndexAttr(inputType.getDimSize(axis)));
|
||||
continue;
|
||||
}
|
||||
|
||||
Value axisIndex = computeLaneIndex(
|
||||
args.lane, keptAxisStrides[keptAxisIndex], inputType.getDimSize(axis), rewriter, loc);
|
||||
++keptAxisIndex;
|
||||
sliceOffsets.push_back(axisIndex);
|
||||
sliceSizes.push_back(rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
insertOffsets.push_back(args.lane);
|
||||
insertOffsets.append(inputType.getRank() - 1, rewriter.getIndexAttr(0));
|
||||
|
||||
Value slice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, sliceType, args.inputs.front(), sliceOffsets, sliceSizes, unitStrides);
|
||||
Value reduced = spatial::SpatVAvgOp::create(rewriter, loc, leafType, slice).getResult();
|
||||
createParallelInsertSliceIntoBatchOutput(
|
||||
rewriter, loc, reduced, args.outputs.front(), insertOffsets, insertSizes, unitStrides);
|
||||
});
|
||||
if (failed(batchOp))
|
||||
return failure();
|
||||
return (*batchOp).getResult(0);
|
||||
}
|
||||
|
||||
static Value buildKeepdimsFromLanePackedBatch(Value batchValue,
|
||||
RankedTensorType keepdimsType,
|
||||
RankedTensorType compactKeptType,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto batchType = cast<RankedTensorType>(batchValue.getType());
|
||||
if (batchType == keepdimsType)
|
||||
return batchValue;
|
||||
|
||||
SmallVector<ReassociationIndices> collapseToFlat {{}};
|
||||
for (int64_t axis = 0; axis < batchType.getRank(); ++axis)
|
||||
collapseToFlat.front().push_back(axis);
|
||||
|
||||
SmallVector<ReassociationIndices> expandFlatToCompact(1);
|
||||
for (int64_t axis = 0; axis < compactKeptType.getRank(); ++axis)
|
||||
expandFlatToCompact.front().push_back(axis);
|
||||
|
||||
SmallVector<ReassociationIndices> expandCompactToKeepdims;
|
||||
ReassociationIndices pendingLeadingReducedAxes;
|
||||
for (auto [axis, isReduced] : llvm::enumerate(reducedAxes)) {
|
||||
if (isReduced) {
|
||||
if (expandCompactToKeepdims.empty())
|
||||
pendingLeadingReducedAxes.push_back(axis);
|
||||
else
|
||||
expandCompactToKeepdims.back().push_back(axis);
|
||||
continue;
|
||||
}
|
||||
|
||||
expandCompactToKeepdims.emplace_back();
|
||||
auto& group = expandCompactToKeepdims.back();
|
||||
group.append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
|
||||
pendingLeadingReducedAxes.clear();
|
||||
group.push_back(axis);
|
||||
}
|
||||
if (!pendingLeadingReducedAxes.empty())
|
||||
expandCompactToKeepdims.back().append(pendingLeadingReducedAxes.begin(), pendingLeadingReducedAxes.end());
|
||||
|
||||
auto reshapeCompute =
|
||||
createSpatCompute<1>(rewriter, loc, TypeRange {keepdimsType}, {}, ValueRange {batchValue}, [&](Value input) {
|
||||
auto flatType =
|
||||
RankedTensorType::get({batchType.getDimSize(0)}, batchType.getElementType(), batchType.getEncoding());
|
||||
Value flat = tensor::CollapseShapeOp::create(rewriter, loc, flatType, input, collapseToFlat);
|
||||
Value compact = flat;
|
||||
if (compactKeptType != flatType)
|
||||
compact = tensor::ExpandShapeOp::create(rewriter, loc, compactKeptType, flat, expandFlatToCompact);
|
||||
Value keepdims = compact;
|
||||
if (keepdimsType != compactKeptType)
|
||||
keepdims = tensor::ExpandShapeOp::create(rewriter, loc, keepdimsType, compact, expandCompactToKeepdims);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, keepdims);
|
||||
});
|
||||
return reshapeCompute.getResult(0);
|
||||
}
|
||||
|
||||
static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<bool> reducedAxes) {
|
||||
SmallVector<ReassociationIndices> reassociation;
|
||||
ReassociationIndices currentGroup;
|
||||
@@ -72,56 +233,6 @@ static SmallVector<ReassociationIndices> buildCollapseReassociation(ArrayRef<boo
|
||||
return reassociation;
|
||||
}
|
||||
|
||||
static Value
|
||||
createAverageCompute(Value input, RankedTensorType resultType, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
constexpr size_t numInputs = 1;
|
||||
auto computeOp = createSpatCompute<numInputs>(rewriter, loc, resultType, {}, ValueRange {input}, [&](Value x) {
|
||||
auto avgOp = spatial::SpatVAvgOp::create(rewriter, loc, resultType, x);
|
||||
spatial::SpatYieldOp::create(rewriter, loc, avgOp.getResult());
|
||||
});
|
||||
return computeOp.getResult(0);
|
||||
}
|
||||
|
||||
static Value concatValues(ValueRange inputs, int64_t axis, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto firstType = cast<RankedTensorType>(inputs.front().getType());
|
||||
SmallVector<int64_t> outputShape(firstType.getShape().begin(), firstType.getShape().end());
|
||||
int64_t concatDimSize = 0;
|
||||
for (Value input : inputs)
|
||||
concatDimSize += cast<RankedTensorType>(input.getType()).getDimSize(axis);
|
||||
outputShape[axis] = concatDimSize;
|
||||
auto resultType = RankedTensorType::get(outputShape, firstType.getElementType(), firstType.getEncoding());
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue))
|
||||
return createSpatConcat(rewriter, loc, axis, inputs);
|
||||
|
||||
auto concatCompute = createSpatCompute(rewriter, loc, TypeRange {resultType}, {}, inputs, [&](ValueRange args) {
|
||||
spatial::SpatYieldOp::create(rewriter, loc, createSpatConcat(rewriter, loc, axis, args));
|
||||
});
|
||||
return concatCompute.getResult(0);
|
||||
}
|
||||
|
||||
static Value buildReduceMeanKeepdims(Value input,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
int64_t axis,
|
||||
RankedTensorType leafType,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
int64_t rank = cast<RankedTensorType>(input.getType()).getRank();
|
||||
if (axis == rank)
|
||||
return createAverageCompute(input, leafType, rewriter, loc);
|
||||
|
||||
if (reducedAxes[axis])
|
||||
return buildReduceMeanKeepdims(input, reducedAxes, axis + 1, leafType, rewriter, loc);
|
||||
|
||||
SmallVector<Value> slices = sliceTensor(input, axis, /*sliceSize=*/1, rewriter, loc);
|
||||
SmallVector<Value> reducedSlices;
|
||||
reducedSlices.reserve(slices.size());
|
||||
for (Value slice : slices)
|
||||
reducedSlices.push_back(buildReduceMeanKeepdims(slice, reducedAxes, axis + 1, leafType, rewriter, loc));
|
||||
|
||||
return concatValues(reducedSlices, axis, rewriter, loc);
|
||||
}
|
||||
|
||||
static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<bool> reducedAxes,
|
||||
@@ -129,13 +240,13 @@ static Value squeezeReducedAxes(Value keepdimsValue,
|
||||
Location loc) {
|
||||
if (resultType.getRank() == 0) {
|
||||
SmallVector<Value> indices(cast<RankedTensorType>(keepdimsValue.getType()).getRank(),
|
||||
arith::ConstantIndexOp::create(rewriter, loc, 0));
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0));
|
||||
Value element = tensor::ExtractOp::create(rewriter, loc, keepdimsValue, indices);
|
||||
return tensor::FromElementsOp::create(rewriter, loc, resultType, ValueRange {element});
|
||||
}
|
||||
|
||||
auto reassociation = buildCollapseReassociation(reducedAxes);
|
||||
if (isHostFoldableValue(keepdimsValue))
|
||||
if (isCompileTimeComputable(keepdimsValue))
|
||||
return tensor::CollapseShapeOp::create(rewriter, loc, resultType, keepdimsValue, reassociation).getResult();
|
||||
|
||||
auto squeezeCompute =
|
||||
@@ -156,16 +267,33 @@ struct ReduceMeanToSpatialCompute : OpConversionPattern<ONNXReduceMeanV13Op> {
|
||||
auto resultType = dyn_cast<RankedTensorType>(reduceMeanOp.getReduced().getType());
|
||||
if (!inputType || !resultType || !inputType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
if (inputType.getRank() == 0) {
|
||||
rewriter.replaceOp(reduceMeanOp, adaptor.getData());
|
||||
return success();
|
||||
}
|
||||
|
||||
SmallVector<int64_t> axes = normalizeAxes(reduceMeanOp.getAxesAttr(), inputType.getRank());
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(axes, inputType.getRank());
|
||||
auto axes = normalizeAxesChecked(std::optional<ArrayAttr>(reduceMeanOp.getAxesAttr()), inputType.getRank());
|
||||
if (failed(axes))
|
||||
return failure();
|
||||
SmallVector<bool> reducedAxes = buildReducedAxesMask(*axes, inputType.getRank());
|
||||
if (reducedAxes.empty() && inputType.getRank() != 0)
|
||||
return failure();
|
||||
|
||||
Location loc = reduceMeanOp.getLoc();
|
||||
RankedTensorType leafType = getAllOnesType(inputType, resultType.getElementType());
|
||||
RankedTensorType compactKeptType = getCompactKeptType(inputType, resultType.getElementType(), reducedAxes);
|
||||
RankedTensorType keepdimsType = getKeepdimsType(inputType, resultType.getElementType(), reducedAxes);
|
||||
int64_t laneCount = 1;
|
||||
for (int64_t dim : compactKeptType.getShape())
|
||||
laneCount *= dim;
|
||||
RankedTensorType batchType = getLanePackedKeepdimsType(laneCount, leafType);
|
||||
|
||||
auto lanePackedKeepdims =
|
||||
buildReduceMeanKeepdimsBatch(adaptor.getData(), reducedAxes, batchType, leafType, rewriter, loc);
|
||||
if (failed(lanePackedKeepdims))
|
||||
return failure();
|
||||
Value reducedKeepdims =
|
||||
buildReduceMeanKeepdims(adaptor.getData(), reducedAxes, /*axis=*/0, leafType, rewriter, loc);
|
||||
buildKeepdimsFromLanePackedBatch(*lanePackedKeepdims, keepdimsType, compactKeptType, reducedAxes, rewriter, loc);
|
||||
|
||||
if (reduceMeanOp.getKeepdims() != 0) {
|
||||
rewriter.replaceOp(reduceMeanOp, reducedKeepdims);
|
||||
|
||||
@@ -23,43 +23,26 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getI64(ArrayAttrT arrayAttr, size_t index) {
|
||||
return cast<IntegerAttr>(arrayAttr[index]).getInt();
|
||||
}
|
||||
|
||||
template <typename ArrayAttrT>
|
||||
static int64_t getOptionalI64(std::optional<ArrayAttrT> arrayAttr, size_t index, int64_t defaultValue) {
|
||||
return arrayAttr ? getI64(*arrayAttr, index) : defaultValue;
|
||||
}
|
||||
|
||||
static Value materializeContiguousTile(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
static Value materializeTileTensor(ConversionPatternRewriter& rewriter, Location loc, Value tile) {
|
||||
auto tileType = cast<RankedTensorType>(tile.getType());
|
||||
Value empty = tensor::EmptyOp::create(rewriter, loc, tileType.getShape(), tileType.getElementType());
|
||||
|
||||
SmallVector<OpFoldResult> offsets(tileType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(tileType.getRank());
|
||||
for (int64_t dimSize : tileType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dimSize));
|
||||
SmallVector<OpFoldResult> strides(tileType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
return tensor::InsertSliceOp::create(rewriter, loc, tile, empty, offsets, sizes, strides);
|
||||
return insertStaticSlice(rewriter, loc, tile, empty, getZeroOffsets(rewriter, tileType.getRank()));
|
||||
}
|
||||
|
||||
static Value
|
||||
createPoolFillElement(ConversionPatternRewriter& rewriter, Location loc, Type elementType, bool useMinimumValue) {
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
if (!useMinimumValue)
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getZeroAttr(elementType));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getZeroAttr(elementType), elementType);
|
||||
|
||||
if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||
auto minValue = llvm::APFloat::getInf(floatType.getFloatSemantics(), /*Negative=*/true);
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getFloatAttr(floatType, minValue));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getFloatAttr(floatType, minValue), elementType);
|
||||
}
|
||||
|
||||
if (auto integerType = dyn_cast<IntegerType>(elementType)) {
|
||||
auto minValue = llvm::APInt::getSignedMinValue(integerType.getWidth());
|
||||
return arith::ConstantOp::create(rewriter, loc, elementType, rewriter.getIntegerAttr(integerType, minValue));
|
||||
return getOrCreateConstant(rewriter, anchorOp, rewriter.getIntegerAttr(integerType, minValue), elementType);
|
||||
}
|
||||
|
||||
llvm_unreachable("unsupported pool element type");
|
||||
@@ -166,7 +149,7 @@ static FailureOr<Value> createAverageScaleTensor(ConversionPatternRewriter& rewr
|
||||
}
|
||||
|
||||
auto scaleAttr = DenseElementsAttr::get(scaleType, scaleValues);
|
||||
return arith::ConstantOp::create(rewriter, loc, scaleType, scaleAttr).getResult();
|
||||
return getOrCreateConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scaleAttr, scaleType);
|
||||
}
|
||||
|
||||
template <typename PoolOp>
|
||||
@@ -197,12 +180,12 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
const int64_t inputWidth = xType.getDimSize(3);
|
||||
const int64_t outputHeight = outType.getDimSize(2);
|
||||
const int64_t outputWidth = outType.getDimSize(3);
|
||||
const int64_t kernelHeight = getI64(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64(poolOp.getDilations(), 1, 1);
|
||||
const int64_t kernelHeight = getI64Attr(kernelAttr, 0);
|
||||
const int64_t kernelWidth = getI64Attr(kernelAttr, 1);
|
||||
const int64_t strideHeight = getOptionalI64Attr(poolOp.getStrides(), 0, 1);
|
||||
const int64_t strideWidth = getOptionalI64Attr(poolOp.getStrides(), 1, 1);
|
||||
const int64_t dilationHeight = getOptionalI64Attr(poolOp.getDilations(), 0, 1);
|
||||
const int64_t dilationWidth = getOptionalI64Attr(poolOp.getDilations(), 1, 1);
|
||||
|
||||
int64_t padTop = 0;
|
||||
int64_t padLeft = 0;
|
||||
@@ -212,10 +195,10 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
if (auto padsAttr = poolOp.getPads()) {
|
||||
if (padsAttr->size() != 4)
|
||||
return rewriter.notifyMatchFailure(poolOp, "pads must have four elements.");
|
||||
padTop = getI64(*padsAttr, 0);
|
||||
padLeft = getI64(*padsAttr, 1);
|
||||
padBottom = getI64(*padsAttr, 2);
|
||||
padRight = getI64(*padsAttr, 3);
|
||||
padTop = getI64Attr(*padsAttr, 0);
|
||||
padLeft = getI64Attr(*padsAttr, 1);
|
||||
padBottom = getI64Attr(*padsAttr, 2);
|
||||
padRight = getI64Attr(*padsAttr, 3);
|
||||
}
|
||||
else {
|
||||
StringRef autoPad = poolOp.getAutoPad();
|
||||
@@ -283,13 +266,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
createPaddedPoolInput(rewriter, loc, poolOp, xArg, xType, padTop, padLeft, padBottom, padRight);
|
||||
Value pooledOutputInit = tensor::EmptyOp::create(rewriter, loc, outType.getShape(), outType.getElementType());
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputPatchCount = arith::ConstantIndexOp::create(rewriter, loc, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = arith::ConstantIndexOp::create(rewriter, loc, outputHeight * outputWidth);
|
||||
Value cOutputWidth = arith::ConstantIndexOp::create(rewriter, loc, outputWidth);
|
||||
Value cStrideHeight = arith::ConstantIndexOp::create(rewriter, loc, strideHeight);
|
||||
Value cStrideWidth = arith::ConstantIndexOp::create(rewriter, loc, strideWidth);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cOutputPatchCount = getOrCreateIndexConstant(rewriter, anchorOp, outputPatchCount);
|
||||
Value cOutputPixelsPerBatch = getOrCreateIndexConstant(rewriter, anchorOp, outputHeight * outputWidth);
|
||||
Value cOutputWidth = getOrCreateIndexConstant(rewriter, anchorOp, outputWidth);
|
||||
Value cStrideHeight = getOrCreateIndexConstant(rewriter, anchorOp, strideHeight);
|
||||
Value cStrideWidth = getOrCreateIndexConstant(rewriter, anchorOp, strideWidth);
|
||||
|
||||
auto outputLoop = scf::ForOp::create(rewriter, loc, c0, cOutputPatchCount, c1, ValueRange {pooledOutputInit});
|
||||
rewriter.setInsertionPointToStart(outputLoop.getBody());
|
||||
@@ -314,14 +298,14 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
for (int64_t kernelH = 0; kernelH < kernelHeight; ++kernelH) {
|
||||
Value paddedInH = windowBaseH;
|
||||
if (kernelH * dilationHeight != 0) {
|
||||
Value kernelHOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelH * dilationHeight);
|
||||
Value kernelHOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelH * dilationHeight);
|
||||
paddedInH = arith::AddIOp::create(rewriter, loc, paddedInH, kernelHOffset);
|
||||
}
|
||||
|
||||
for (int64_t kernelW = 0; kernelW < kernelWidth; ++kernelW) {
|
||||
Value paddedInW = windowBaseW;
|
||||
if (kernelW * dilationWidth != 0) {
|
||||
Value kernelWOffset = arith::ConstantIndexOp::create(rewriter, loc, kernelW * dilationWidth);
|
||||
Value kernelWOffset = getOrCreateIndexConstant(rewriter, anchorOp, kernelW * dilationWidth);
|
||||
paddedInW = arith::AddIOp::create(rewriter, loc, paddedInW, kernelWOffset);
|
||||
}
|
||||
|
||||
@@ -335,7 +319,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value windowValue =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, tileType, paddedInput, offsets, sizes, strides);
|
||||
windowValue = materializeContiguousTile(rewriter, loc, windowValue);
|
||||
windowValue = materializeTileTensor(rewriter, loc, windowValue);
|
||||
reducedWindow = ReduceOp::create(rewriter, loc, tileType, reducedWindow, windowValue);
|
||||
}
|
||||
}
|
||||
@@ -351,7 +335,7 @@ struct PoolToSpatialComputeBase : public OpConversionPattern<PoolOp> {
|
||||
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value scaleSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, tileType, averageScaleTensor, scaleOffsets, scaleSizes, scaleStrides);
|
||||
scaleSlice = materializeContiguousTile(rewriter, loc, scaleSlice);
|
||||
scaleSlice = materializeTileTensor(rewriter, loc, scaleSlice);
|
||||
reducedWindow = spatial::SpatVMulOp::create(rewriter, loc, tileType, reducedWindow, scaleSlice);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -13,16 +13,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> shape, ArrayRef<int64_t> permutation) {
|
||||
SmallVector<int64_t> permutedShape;
|
||||
permutedShape.reserve(permutation.size());
|
||||
for (int64_t axis : permutation)
|
||||
permutedShape.push_back(shape[axis]);
|
||||
return permutedShape;
|
||||
}
|
||||
|
||||
static Value buildLoopSoftmaxSlice(Value input,
|
||||
Value accumulator,
|
||||
RankedTensorType inputType,
|
||||
@@ -36,7 +26,7 @@ static Value buildLoopSoftmaxSlice(Value input,
|
||||
|
||||
SmallVector<OpFoldResult> offsets;
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> strides = getUnitStrides(rewriter, rank);
|
||||
offsets.reserve(rank);
|
||||
sizes.reserve(rank);
|
||||
|
||||
@@ -62,9 +52,10 @@ static Value buildLoopSoftmaxNest(Value input,
|
||||
if (axis == inputType.getRank() - 1)
|
||||
return buildLoopSoftmaxSlice(input, accumulator, inputType, outerIndices, rewriter, loc);
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cUpper = arith::ConstantIndexOp::create(rewriter, loc, inputType.getDimSize(axis));
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cUpper = getOrCreateIndexConstant(rewriter, anchorOp, inputType.getDimSize(axis));
|
||||
|
||||
auto loop = scf::ForOp::create(rewriter, loc, c0, cUpper, c1, ValueRange {accumulator});
|
||||
rewriter.setInsertionPointToStart(loop.getBody());
|
||||
@@ -110,44 +101,29 @@ struct SoftmaxToSpatialCompute : OpConversionPattern<ONNXSoftmaxOp> {
|
||||
if (!inputType || !inputType.hasStaticShape())
|
||||
return failure();
|
||||
|
||||
int64_t axis = normalizeAxis(softmaxOp.getAxis(), inputType.getRank());
|
||||
if (axis < 0 || axis >= inputType.getRank())
|
||||
auto axis = normalizeAxisChecked(softmaxOp.getAxis(), inputType.getRank());
|
||||
if (failed(axis))
|
||||
return failure();
|
||||
|
||||
Value input = adaptor.getInput();
|
||||
Value result;
|
||||
if (axis == inputType.getRank() - 1) {
|
||||
if (*axis == inputType.getRank() - 1) {
|
||||
result = createLoopSoftmaxCompute(input, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
else {
|
||||
SmallVector<int64_t> permutation;
|
||||
permutation.reserve(inputType.getRank());
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim)
|
||||
if (dim != axis)
|
||||
if (dim != *axis)
|
||||
permutation.push_back(dim);
|
||||
permutation.push_back(axis);
|
||||
|
||||
SmallVector<int64_t> inversePermutation(inputType.getRank());
|
||||
for (auto [newIndex, oldIndex] : llvm::enumerate(permutation))
|
||||
inversePermutation[oldIndex] = static_cast<int64_t>(newIndex);
|
||||
permutation.push_back(*axis);
|
||||
SmallVector<int64_t> inversePermutation = invertPermutation(permutation);
|
||||
|
||||
auto transposedType = RankedTensorType::get(
|
||||
permuteShape(inputType.getShape(), permutation), inputType.getElementType(), inputType.getEncoding());
|
||||
auto preTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {transposedType}, {}, input, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), transposedType, x, rewriter.getI64ArrayAttr(permutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
Value transposedInput = preTransposeCompute.getResult(0);
|
||||
Value transposedInput = transposeMaybeInCompute(input, transposedType, permutation, rewriter, softmaxOp.getLoc());
|
||||
Value transposedResult = createLoopSoftmaxCompute(transposedInput, rewriter, softmaxOp.getLoc());
|
||||
auto postTransposeCompute =
|
||||
createSpatCompute<1>(rewriter, softmaxOp.getLoc(), TypeRange {inputType}, {}, transposedResult, [&](Value x) {
|
||||
Value transposed = ONNXTransposeOp::create(
|
||||
rewriter, softmaxOp.getLoc(), inputType, x, rewriter.getI64ArrayAttr(inversePermutation));
|
||||
spatial::SpatYieldOp::create(rewriter, softmaxOp.getLoc(), transposed);
|
||||
});
|
||||
result = postTransposeCompute.getResult(0);
|
||||
result = transposeMaybeInCompute(transposedResult, inputType, inversePermutation, rewriter, softmaxOp.getLoc());
|
||||
}
|
||||
|
||||
rewriter.replaceOp(softmaxOp, result);
|
||||
|
||||
@@ -0,0 +1,288 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, linalg::TransposeOp>(op);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
||||
return arg && canPromoteInputBlockArgument(*arg);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
struct PromotedOperands {
|
||||
SmallVector<bool> promoteInput;
|
||||
SmallVector<Value> newWeights;
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
};
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static FailureOr<PromotedOperands> computePromotedOperands(ComputeOpTy compute) {
|
||||
PromotedOperands promoted;
|
||||
promoted.promoteInput.assign(compute.getInputs().size(), false);
|
||||
promoted.newWeights.append(compute.getWeights().begin(), compute.getWeights().end());
|
||||
promoted.newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
promoted.newInputs.reserve(compute.getInputs().size());
|
||||
promoted.newInputTypes.reserve(compute.getInputs().size());
|
||||
promoted.newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
bool needsRewrite = false;
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
goto keep_input;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(compute.getInputArgument(inputIdx)))
|
||||
goto keep_input;
|
||||
promoted.promoteInput[inputIdx] = true;
|
||||
promoted.newWeights.push_back(input);
|
||||
needsRewrite = true;
|
||||
continue;
|
||||
|
||||
keep_input:
|
||||
promoted.newInputs.push_back(input);
|
||||
promoted.newInputTypes.push_back(input.getType());
|
||||
promoted.newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
if (!needsRewrite)
|
||||
return failure();
|
||||
return promoted;
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static LogicalResult mapPromotedInputArguments(ComputeOpTy compute,
|
||||
const PromotedOperands& promoted,
|
||||
IRRewriter& bodyRewriter,
|
||||
IRMapping& mapper,
|
||||
std::function<std::optional<BlockArgument>(size_t)> getNewInputArg,
|
||||
PatternRewriter& rewriter) {
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing input block argument during rewrite");
|
||||
if (!promoted.promoteInput[oldInputIdx]) {
|
||||
auto newInputArg = getNewInputArg(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
auto promoted = computePromotedOperands(compute);
|
||||
if (failed(promoted))
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
auto newCompute = spatial::SpatCompute::create(
|
||||
rewriter, compute.getLoc(), compute.getResultTypes(), promoted->newWeights, promoted->newInputs);
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
for (Value weight : promoted->newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
if (failed(mapPromotedInputArguments(
|
||||
compute,
|
||||
*promoted,
|
||||
bodyRewriter,
|
||||
mapper,
|
||||
[&](size_t index) { return newCompute.getInputArgument(index); },
|
||||
rewriter)))
|
||||
return failure();
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
auto promoted = computePromotedOperands(compute);
|
||||
if (failed(promoted))
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
promoted->newWeights,
|
||||
promoted->newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + promoted->newWeights.size() + promoted->newInputTypes.size()
|
||||
+ compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + promoted->newWeights.size() + promoted->newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
for (Value weight : promoted->newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
}
|
||||
llvm::append_range(newBlockArgTypes, promoted->newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, promoted->newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(outputArg->getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
&newCompute.getBody(), newCompute.getBody().end(), TypeRange(newBlockArgTypes), newBlockArgLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(promoted->newWeights.size()), static_cast<int>(promoted->newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
auto newLaneArg = newCompute.getLaneArgument();
|
||||
if (!newLaneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
||||
mapper.map(*laneArg, *newLaneArg);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
if (failed(mapPromotedInputArguments(
|
||||
compute,
|
||||
*promoted,
|
||||
bodyRewriter,
|
||||
mapper,
|
||||
[&](size_t index) { return newCompute.getInputArgument(index); },
|
||||
rewriter)))
|
||||
return failure();
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||
mapper.map(*outputArg,
|
||||
newBlock->getArgument(1 + promoted->newWeights.size() + promoted->newInputs.size() + resultIndex));
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateWeightPromotionPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+2
-3
@@ -1,6 +1,5 @@
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PrePatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@@ -12,7 +11,7 @@ namespace {
|
||||
|
||||
} // namespace
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
void populateGeneratedPrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx) {
|
||||
patterns.add<onnxToArithConstant>(ctx);
|
||||
patterns.add<convAddToConvWithBiasLeft>(ctx);
|
||||
patterns.add<convAddToConvWithBiasRight>(ctx);
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/ComputeRegionBuilder.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -20,7 +20,7 @@ struct Concat : public OpConversionPattern<ONNXConcatOp> {
|
||||
auto inputs = adaptor.getInputs();
|
||||
int64_t axis = adaptor.getAxis();
|
||||
|
||||
if (llvm::all_of(inputs, isHostFoldableValue)) {
|
||||
if (llvm::all_of(inputs, isCompileTimeComputable)) {
|
||||
rewriter.replaceOp(maxpoolOp, createSpatConcat(rewriter, maxpoolOp.getLoc(), axis, inputs));
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -15,24 +15,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
static int64_t normalizeIndex(int64_t index, int64_t dimSize) { return index >= 0 ? index : dimSize + index; }
|
||||
|
||||
static Value
|
||||
extractSliceAt(Value input, int64_t axis, int64_t offset, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(inputType.getRank());
|
||||
for (int64_t dim : inputType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(1);
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, input, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
static Value concatGatherSlices(Value data,
|
||||
int64_t axis,
|
||||
ArrayRef<int64_t> indices,
|
||||
@@ -45,7 +27,7 @@ static Value concatGatherSlices(Value data,
|
||||
int64_t normalizedIndex = normalizeIndex(index, axisDim);
|
||||
if (normalizedIndex < 0 || normalizedIndex >= axisDim)
|
||||
return {};
|
||||
slices.push_back(extractSliceAt(data, axis, normalizedIndex, rewriter, loc));
|
||||
slices.push_back(extractAxisSlice(rewriter, loc, data, axis, normalizedIndex, /*size=*/1));
|
||||
}
|
||||
if (slices.empty())
|
||||
return {};
|
||||
@@ -96,11 +78,11 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
||||
return failure();
|
||||
|
||||
int64_t rank = dataType.getRank();
|
||||
int64_t axis = normalizeAxis(gatherOp.getAxis(), rank);
|
||||
if (axis < 0 || axis >= rank)
|
||||
auto axis = normalizeAxisChecked(gatherOp.getAxis(), rank);
|
||||
if (failed(axis))
|
||||
return failure();
|
||||
|
||||
int64_t axisDim = dataType.getShape()[axis];
|
||||
int64_t axisDim = dataType.getShape()[*axis];
|
||||
if (axisDim <= 0)
|
||||
return failure();
|
||||
|
||||
@@ -116,7 +98,7 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
||||
[&](Value data) -> LogicalResult {
|
||||
Value result;
|
||||
if (indicesType.getRank() == 1) {
|
||||
result = concatGatherSlices(data, axis, flatIndices, axisDim, rewriter, loc);
|
||||
result = concatGatherSlices(data, *axis, flatIndices, axisDim, rewriter, loc);
|
||||
}
|
||||
else if (indicesType.getRank() == 2) {
|
||||
int64_t rowCount = indicesType.getShape()[0];
|
||||
@@ -125,12 +107,13 @@ struct Gather : OpConversionPattern<ONNXGatherOp> {
|
||||
rows.reserve(rowCount);
|
||||
for (int64_t row = 0; row < rowCount; ++row) {
|
||||
ArrayRef<int64_t> rowIndices(flatIndices.data() + row * rowWidth, rowWidth);
|
||||
Value gatheredRow = concatGatherSlices(data, axis, rowIndices, axisDim, rewriter, loc);
|
||||
Value gatheredRow =
|
||||
concatGatherSlices(data, *axis, rowIndices, axisDim, rewriter, loc);
|
||||
if (!gatheredRow)
|
||||
return failure();
|
||||
rows.push_back(addLeadingGatherDim(gatheredRow, axis, rewriter, loc));
|
||||
rows.push_back(addLeadingGatherDim(gatheredRow, *axis, rewriter, loc));
|
||||
}
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/axis, rows);
|
||||
result = createSpatConcat(rewriter, loc, /*axis=*/*axis, rows);
|
||||
}
|
||||
else {
|
||||
return failure();
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -14,10 +14,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static bool haveStaticPositiveShape(ArrayRef<int64_t> shape) {
|
||||
return llvm::all_of(shape, [](int64_t dim) { return dim > 0; });
|
||||
}
|
||||
|
||||
static bool inferCollapseReassociation(ArrayRef<int64_t> sourceShape,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
SmallVector<ReassociationIndices>& reassociation) {
|
||||
@@ -106,7 +102,7 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
auto resultType = dyn_cast<RankedTensorType>(reshapeOp.getReshaped().getType());
|
||||
if (!sourceType || !resultType || !sourceType.hasStaticShape() || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
if (!haveStaticPositiveShape(sourceType.getShape()) || !haveStaticPositiveShape(resultType.getShape()))
|
||||
if (!hasStaticPositiveShape(sourceType) || !hasStaticPositiveShape(resultType))
|
||||
return failure();
|
||||
|
||||
if (sourceType == resultType) {
|
||||
@@ -115,17 +111,9 @@ struct Reshape : OpConversionPattern<ONNXReshapeOp> {
|
||||
}
|
||||
|
||||
auto replaceWithReshape = [&](auto buildReshape) -> LogicalResult {
|
||||
if (isHostFoldableValue(adaptor.getData())) {
|
||||
rewriter.replaceOp(reshapeOp, buildReshape(adaptor.getData()));
|
||||
return success();
|
||||
}
|
||||
|
||||
auto computeOp = createSpatCompute<1>(
|
||||
rewriter, reshapeOp.getLoc(), TypeRange {resultType}, {}, adaptor.getData(), [&](Value data) {
|
||||
Value reshaped = buildReshape(data);
|
||||
spatial::SpatYieldOp::create(rewriter, reshapeOp.getLoc(), reshaped);
|
||||
});
|
||||
rewriter.replaceOp(reshapeOp, computeOp.getResults());
|
||||
Value reshaped =
|
||||
materializeOrComputeUnary(adaptor.getData(), resultType, rewriter, reshapeOp.getLoc(), buildReshape);
|
||||
rewriter.replaceOp(reshapeOp, reshaped);
|
||||
return success();
|
||||
};
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -17,9 +17,10 @@ namespace {
|
||||
|
||||
static Value buildNearestAsymmetricIndex(
|
||||
Value outputIndex, int64_t inputDim, int64_t outputDim, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
Value cInputDim = arith::ConstantIndexOp::create(rewriter, loc, inputDim);
|
||||
Value cOutputDim = arith::ConstantIndexOp::create(rewriter, loc, outputDim);
|
||||
Value cInputDimLast = arith::ConstantIndexOp::create(rewriter, loc, inputDim - 1);
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value cInputDim = getOrCreateIndexConstant(rewriter, anchorOp, inputDim);
|
||||
Value cOutputDim = getOrCreateIndexConstant(rewriter, anchorOp, outputDim);
|
||||
Value cInputDimLast = getOrCreateIndexConstant(rewriter, anchorOp, inputDim - 1);
|
||||
Value scaledIndex = arith::MulIOp::create(rewriter, loc, outputIndex, cInputDim);
|
||||
Value inputIndex = arith::DivUIOp::create(rewriter, loc, scaledIndex, cOutputDim);
|
||||
return arith::MinUIOp::create(rewriter, loc, inputIndex, cInputDimLast);
|
||||
@@ -37,12 +38,13 @@ static Value buildNearestResizeLoop(Value input,
|
||||
SmallVector<OpFoldResult> unitSizes(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
SmallVector<OpFoldResult> unitStrides(resultType.getRank(), rewriter.getIndexAttr(1));
|
||||
|
||||
Value c0 = arith::ConstantIndexOp::create(rewriter, loc, 0);
|
||||
Value c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
|
||||
Value cOutputN = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(0));
|
||||
Value cOutputC = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(1));
|
||||
Value cOutputH = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(2));
|
||||
Value cOutputW = arith::ConstantIndexOp::create(rewriter, loc, resultType.getDimSize(3));
|
||||
Operation* anchorOp = rewriter.getInsertionBlock()->getParentOp();
|
||||
Value c0 = getOrCreateIndexConstant(rewriter, anchorOp, 0);
|
||||
Value c1 = getOrCreateIndexConstant(rewriter, anchorOp, 1);
|
||||
Value cOutputN = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(0));
|
||||
Value cOutputC = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(1));
|
||||
Value cOutputH = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(2));
|
||||
Value cOutputW = getOrCreateIndexConstant(rewriter, anchorOp, resultType.getDimSize(3));
|
||||
|
||||
Value outputInit = tensor::EmptyOp::create(rewriter, loc, resultType.getShape(), elemType);
|
||||
|
||||
@@ -58,24 +60,21 @@ static Value buildNearestResizeLoop(Value input,
|
||||
|
||||
Value outputC = channelLoop.getInductionVar();
|
||||
Value outputChannelAcc = channelLoop.getRegionIterArgs().front();
|
||||
Value inputC =
|
||||
buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
Value inputC = buildNearestAsymmetricIndex(outputC, inputType.getDimSize(1), resultType.getDimSize(1), rewriter, loc);
|
||||
|
||||
auto heightLoop = scf::ForOp::create(rewriter, loc, c0, cOutputH, c1, ValueRange {outputChannelAcc});
|
||||
rewriter.setInsertionPointToStart(heightLoop.getBody());
|
||||
|
||||
Value outputH = heightLoop.getInductionVar();
|
||||
Value outputHeightAcc = heightLoop.getRegionIterArgs().front();
|
||||
Value inputH =
|
||||
buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
Value inputH = buildNearestAsymmetricIndex(outputH, inputType.getDimSize(2), resultType.getDimSize(2), rewriter, loc);
|
||||
|
||||
auto widthLoop = scf::ForOp::create(rewriter, loc, c0, cOutputW, c1, ValueRange {outputHeightAcc});
|
||||
rewriter.setInsertionPointToStart(widthLoop.getBody());
|
||||
|
||||
Value outputW = widthLoop.getInductionVar();
|
||||
Value outputWidthAcc = widthLoop.getRegionIterArgs().front();
|
||||
Value inputW =
|
||||
buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
Value inputW = buildNearestAsymmetricIndex(outputW, inputType.getDimSize(3), resultType.getDimSize(3), rewriter, loc);
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets = {inputN, inputC, inputH, inputW};
|
||||
Value inputSlice =
|
||||
@@ -114,8 +113,8 @@ struct Resize : OpConversionPattern<ONNXResizeOp> {
|
||||
|
||||
if (resizeOp.getMode() != "nearest" || resizeOp.getCoordinateTransformationMode() != "asymmetric"
|
||||
|| resizeOp.getNearestMode() != "floor")
|
||||
return rewriter.notifyMatchFailure(
|
||||
resizeOp, "resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
return rewriter.notifyMatchFailure(resizeOp,
|
||||
"resize lowering currently supports only nearest + asymmetric + floor.");
|
||||
|
||||
if (llvm::any_of(inputType.getShape(), [](int64_t dim) { return dim <= 0; })
|
||||
|| llvm::any_of(resultType.getShape(), [](int64_t dim) { return dim <= 0; }))
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/ConversionPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/HostFoldability.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/CompileTime.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -12,25 +12,6 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int64_t normalizeAxis(int64_t axis, int64_t rank) { return axis >= 0 ? axis : rank + axis; }
|
||||
|
||||
static Value extractSliceAt(
|
||||
Value input, int64_t axis, int64_t offset, int64_t size, ConversionPatternRewriter& rewriter, Location loc) {
|
||||
auto inputType = cast<RankedTensorType>(input.getType());
|
||||
SmallVector<OpFoldResult> offsets(inputType.getRank(), rewriter.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
SmallVector<OpFoldResult> strides(inputType.getRank(), rewriter.getIndexAttr(1));
|
||||
sizes.reserve(inputType.getRank());
|
||||
for (int64_t dim : inputType.getShape())
|
||||
sizes.push_back(rewriter.getIndexAttr(dim));
|
||||
offsets[axis] = rewriter.getIndexAttr(offset);
|
||||
sizes[axis] = rewriter.getIndexAttr(size);
|
||||
SmallVector<int64_t> resultShape(inputType.getShape());
|
||||
resultShape[axis] = size;
|
||||
auto resultType = RankedTensorType::get(resultShape, inputType.getElementType());
|
||||
return tensor::ExtractSliceOp::create(rewriter, loc, resultType, input, offsets, sizes, strides);
|
||||
}
|
||||
|
||||
struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
@@ -41,8 +22,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
return failure();
|
||||
|
||||
int64_t rank = inputType.getRank();
|
||||
int64_t axis = normalizeAxis(splitOp.getAxis(), rank);
|
||||
if (axis < 0 || axis >= rank)
|
||||
auto axis = normalizeAxisChecked(splitOp.getAxis(), rank);
|
||||
if (failed(axis))
|
||||
return failure();
|
||||
|
||||
SmallVector<Value> outputs;
|
||||
@@ -58,12 +39,12 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
if (!resultType || !resultType.hasStaticShape())
|
||||
return failure();
|
||||
resultTypes.push_back(resultType);
|
||||
sliceSizes.push_back(resultType.getShape()[axis]);
|
||||
sliceSizes.push_back(resultType.getShape()[*axis]);
|
||||
}
|
||||
|
||||
if (isHostFoldableValue(adaptor.getInput())) {
|
||||
if (isCompileTimeComputable(adaptor.getInput())) {
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
outputs.push_back(extractSliceAt(adaptor.getInput(), axis, offset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
outputs.push_back(extractAxisSlice(rewriter, splitOp.getLoc(), adaptor.getInput(), *axis, offset, sliceSize));
|
||||
offset += sliceSize;
|
||||
}
|
||||
rewriter.replaceOp(splitOp, outputs);
|
||||
@@ -76,7 +57,8 @@ struct Split : OpConversionPattern<ONNXSplitOp> {
|
||||
runtimeOutputs.reserve(resultTypes.size());
|
||||
int64_t runtimeOffset = 0;
|
||||
for (int64_t sliceSize : sliceSizes) {
|
||||
runtimeOutputs.push_back(extractSliceAt(input, axis, runtimeOffset, sliceSize, rewriter, splitOp.getLoc()));
|
||||
runtimeOutputs.push_back(
|
||||
extractAxisSlice(rewriter, splitOp.getLoc(), input, *axis, runtimeOffset, sliceSize));
|
||||
runtimeOffset += sliceSize;
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, splitOp.getLoc(), runtimeOutputs);
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Patterns.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static Value createTransposeInit(Value input,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
SmallVector<OpFoldResult> sizes;
|
||||
sizes.reserve(resultType.getRank());
|
||||
for (auto [resultDim, sourceDim] : llvm::zip_equal(resultType.getShape(), permutation)) {
|
||||
if (!ShapedType::isDynamic(resultDim)) {
|
||||
sizes.push_back(rewriter.getIndexAttr(resultDim));
|
||||
continue;
|
||||
}
|
||||
sizes.push_back(tensor::DimOp::create(rewriter, loc, input, sourceDim).getResult());
|
||||
}
|
||||
return tensor::EmptyOp::create(rewriter, loc, sizes, resultType.getElementType()).getResult();
|
||||
}
|
||||
|
||||
static FailureOr<Value> materializeTransposedConstant(Value input,
|
||||
RankedTensorType resultType,
|
||||
ArrayRef<int64_t> permutation,
|
||||
ConversionPatternRewriter& rewriter,
|
||||
Location loc) {
|
||||
auto denseAttr = getHostConstDenseElementsAttr(input);
|
||||
if (!denseAttr)
|
||||
return failure();
|
||||
|
||||
auto inputType = dyn_cast<RankedTensorType>(denseAttr.getType());
|
||||
if (!inputType || !inputType.hasStaticShape() || !resultType.hasStaticShape()
|
||||
|| inputType.getRank() != resultType.getRank()
|
||||
|| static_cast<int64_t>(permutation.size()) != inputType.getRank()) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (denseAttr.isSplat())
|
||||
return getOrCreateConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
DenseElementsAttr::get(resultType, denseAttr.getSplatValue<Attribute>()),
|
||||
resultType);
|
||||
|
||||
SmallVector<Attribute> inputValues(denseAttr.getValues<Attribute>());
|
||||
SmallVector<Attribute> resultValues(inputValues.size());
|
||||
SmallVector<int64_t> inputStrides = computeRowMajorStrides(inputType.getShape());
|
||||
SmallVector<int64_t> resultStrides = computeRowMajorStrides(resultType.getShape());
|
||||
SmallVector<int64_t> inputIndices(inputType.getRank(), 0);
|
||||
|
||||
for (auto [linearIndex, value] : llvm::enumerate(inputValues)) {
|
||||
int64_t remaining = static_cast<int64_t>(linearIndex);
|
||||
for (int64_t dim = 0; dim < inputType.getRank(); ++dim) {
|
||||
inputIndices[dim] = inputStrides.empty() ? 0 : remaining / inputStrides[dim];
|
||||
remaining = inputStrides.empty() ? 0 : remaining % inputStrides[dim];
|
||||
}
|
||||
|
||||
int64_t resultLinearIndex = 0;
|
||||
for (int64_t dim = 0; dim < resultType.getRank(); ++dim)
|
||||
resultLinearIndex += inputIndices[permutation[dim]] * resultStrides[dim];
|
||||
|
||||
resultValues[resultLinearIndex] = value;
|
||||
}
|
||||
|
||||
return getOrCreateConstant(rewriter,
|
||||
rewriter.getInsertionBlock()->getParentOp(),
|
||||
DenseElementsAttr::get(resultType, resultValues),
|
||||
resultType);
|
||||
}
|
||||
|
||||
struct TransposeToLinalgTranspose : OpConversionPattern<ONNXTransposeOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(ONNXTransposeOp transposeOp,
|
||||
ONNXTransposeOpAdaptor adaptor,
|
||||
ConversionPatternRewriter& rewriter) const override {
|
||||
auto inputType = dyn_cast<RankedTensorType>(adaptor.getData().getType());
|
||||
auto resultType = dyn_cast<RankedTensorType>(transposeOp.getResult().getType());
|
||||
if (!inputType || !resultType)
|
||||
return failure();
|
||||
|
||||
auto permutation = getTransposePermutationChecked(transposeOp.getPermAttr(), inputType.getRank());
|
||||
if (failed(permutation))
|
||||
return failure();
|
||||
if (isCompileTimeComputable(adaptor.getData())) {
|
||||
auto constantTranspose =
|
||||
materializeTransposedConstant(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||
if (succeeded(constantTranspose)) {
|
||||
rewriter.replaceOp(transposeOp, *constantTranspose);
|
||||
return success();
|
||||
}
|
||||
}
|
||||
Value init = createTransposeInit(adaptor.getData(), resultType, *permutation, rewriter, transposeOp.getLoc());
|
||||
Value transposed =
|
||||
linalg::TransposeOp::create(rewriter, transposeOp.getLoc(), adaptor.getData(), init, *permutation).getResult()[0];
|
||||
rewriter.replaceOp(transposeOp, transposed);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateTransposePatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<TransposeToLinalgTranspose>(ctx);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,286 +0,0 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Common/IR/WeightUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/Common/WeightMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/ONNXToSpatial/PostPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
namespace {
|
||||
|
||||
static bool isWeightMaterializationHelperUser(Operation* op) {
|
||||
return isa<tensor::ExtractSliceOp, tensor::ExpandShapeOp, tensor::CollapseShapeOp, ONNXTransposeOp>(op);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= block.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(block.getArgument(inputIdx)))
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Collapses one-lane batches so later phases do not carry batch-only structure unnecessarily.
|
||||
struct FoldSingleLaneComputeBatchPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch batchOp, PatternRewriter& rewriter) const override {
|
||||
if (batchOp.getLaneCount() != 1)
|
||||
return rewriter.notifyMatchFailure(batchOp, "requires a single lane");
|
||||
|
||||
auto loc = batchOp.getLoc();
|
||||
rewriter.setInsertionPoint(batchOp);
|
||||
auto computeOp =
|
||||
spatial::SpatCompute::create(rewriter, loc, batchOp.getResultTypes(), batchOp.getWeights(), batchOp.getInputs());
|
||||
computeOp.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(batchOp.getWeights().size()), static_cast<int>(batchOp.getInputs().size())});
|
||||
|
||||
Block& templateBlock = batchOp.getBody().front();
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
blockArgTypes.reserve(templateBlock.getNumArguments());
|
||||
blockArgLocs.reserve(templateBlock.getNumArguments());
|
||||
for (BlockArgument arg : templateBlock.getArguments()) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(loc);
|
||||
}
|
||||
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&computeOp.getBody(), computeOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
IRMapping mapper;
|
||||
for (auto [oldArg, newArg] : llvm::zip(templateBlock.getArguments(), newBlock->getArguments()))
|
||||
mapper.map(oldArg, newArg);
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : templateBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
batchOp->replaceAllUsesWith(computeOp->getResults());
|
||||
rewriter.eraseOp(batchOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable helper chains from runtime inputs to weights to avoid artificial compute inputs.
|
||||
struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCompute> {
|
||||
using OpRewritePattern<spatial::SpatCompute>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatCompute compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatCompute::create(rewriter, compute.getLoc(), compute.getResultTypes(), newWeights, newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Promotes foldable batch helper chains to weights while preserving compact compute_batch IR.
|
||||
struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::SpatComputeBatch> {
|
||||
using OpRewritePattern<spatial::SpatComputeBatch>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatComputeBatch compute, PatternRewriter& rewriter) const override {
|
||||
SmallVector<bool> promoteInput(compute.getInputs().size(), false);
|
||||
bool needsRewrite = false;
|
||||
Block& oldBlock = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (inputIdx >= oldBlock.getNumArguments())
|
||||
continue;
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
if (isDirectConstantValue(input) && !canPromoteInputBlockArgument(oldBlock.getArgument(inputIdx)))
|
||||
continue;
|
||||
promoteInput[inputIdx] = true;
|
||||
needsRewrite = true;
|
||||
}
|
||||
if (!needsRewrite)
|
||||
return rewriter.notifyMatchFailure(compute, "no weight-like batch inputs to promote");
|
||||
|
||||
rewriter.setInsertionPointAfter(compute);
|
||||
|
||||
SmallVector<Value> newWeights(compute.getWeights().begin(), compute.getWeights().end());
|
||||
SmallVector<Value> newInputs;
|
||||
SmallVector<Type> newInputTypes;
|
||||
SmallVector<Location> newInputLocs;
|
||||
newWeights.reserve(compute.getWeights().size() + compute.getInputs().size());
|
||||
newInputs.reserve(compute.getInputs().size());
|
||||
newInputTypes.reserve(compute.getInputs().size());
|
||||
newInputLocs.reserve(compute.getInputs().size());
|
||||
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (promoteInput[inputIdx]) {
|
||||
newWeights.push_back(input);
|
||||
continue;
|
||||
}
|
||||
newInputs.push_back(input);
|
||||
newInputTypes.push_back(input.getType());
|
||||
newInputLocs.push_back(input.getLoc());
|
||||
}
|
||||
|
||||
auto newCompute =
|
||||
spatial::SpatComputeBatch::create(rewriter,
|
||||
compute.getLoc(),
|
||||
compute.getResultTypes(),
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto* newBlock =
|
||||
rewriter.createBlock(&newCompute.getBody(), newCompute.getBody().end(), newInputTypes, newInputLocs);
|
||||
newCompute.getProperties().setOperandSegmentSizes(
|
||||
{static_cast<int>(newWeights.size()), static_cast<int>(newInputs.size())});
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRRewriter bodyRewriter(rewriter.getContext());
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, oldArg] : llvm::enumerate(oldBlock.getArguments())) {
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newBlock->getArgument(newInputIdx++));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(compute.getInputs()[oldInputIdx], bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
SmallVector<Value> newYieldOperands;
|
||||
newYieldOperands.reserve(oldYield.getOutputs().size());
|
||||
for (Value operand : oldYield.getOutputs()) {
|
||||
auto mapped = mapper.lookupOrNull(operand);
|
||||
newYieldOperands.push_back(mapped ? cast<Value>(mapped) : operand);
|
||||
}
|
||||
spatial::SpatYieldOp::create(rewriter, oldYield.getLoc(), newYieldOperands);
|
||||
|
||||
rewriter.replaceOp(compute, newCompute.getResults());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateEarlyPostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<FoldSingleLaneComputeBatchPattern>(ctx);
|
||||
}
|
||||
|
||||
void populatePostPatterns(RewritePatternSet& patterns, MLIRContext* ctx) {
|
||||
patterns.add<PromoteWeightLikeComputeInputsPattern, PromoteWeightLikeComputeBatchInputsPattern>(ctx);
|
||||
}
|
||||
|
||||
void annotateWeightsConstants(func::FuncOp funcOp) {
|
||||
funcOp.walk([&](arith::ConstantOp constantOp) {
|
||||
if (hasOnlySpatialMvmVmmWeightUses(constantOp.getResult()))
|
||||
markWeightAlways(constantOp);
|
||||
});
|
||||
}
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp) { return batchOp.getLaneCount() == 1; }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp) { return hasPromotableWeightLikeInputs(computeOp); }
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,22 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
bool requiresEarlyPostRewrite(spatial::SpatComputeBatch batchOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatCompute computeOp);
|
||||
|
||||
bool requiresPostRewrite(spatial::SpatComputeBatch computeOp);
|
||||
|
||||
void populateEarlyPostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void populatePostPatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
void annotateWeightsConstants(mlir::func::FuncOp funcOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populatePrePatterns(mlir::RewritePatternSet& patterns, mlir::MLIRContext* ctx);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,11 +1,15 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/IR/BatchCoreUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
@@ -15,7 +19,11 @@ using namespace onnx_mlir::pim;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
static bool isUsedOnlyAsExplicitHostOperand(Value value) {
|
||||
return !value.use_empty() && llvm::all_of(value.getUses(), [](OpOperand& use) {
|
||||
return isExplicitDevToHostTargetOperand(use.getOwner(), use.getOperandNumber());
|
||||
});
|
||||
}
|
||||
|
||||
static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch computeBatchOp, size_t& fallbackCoreId) {
|
||||
if (auto coreIdsAttr = computeBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName))
|
||||
@@ -28,54 +36,73 @@ static SmallVector<int32_t> getPimCoreIdsForBatchOp(spatial::SpatComputeBatch co
|
||||
return coreIds;
|
||||
}
|
||||
|
||||
static void lowerChannelSendTensorBatch(spatial::SpatChannelSendTensorBatchOp sendTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(sendTensorBatchOp.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : sendTensorBatchOp.getTargetCoreIds())
|
||||
targetCoreIds.push_back(translateSpatialCoreIdToPimCoreId(targetCoreId));
|
||||
static FailureOr<unsigned> getDirectReturnOperandIndex(OpResult result) {
|
||||
if (!result.hasOneUse())
|
||||
return failure();
|
||||
|
||||
pim::PimSendTensorBatchOp::create(rewriter,
|
||||
sendTensorBatchOp.getLoc(),
|
||||
mapper.lookup(sendTensorBatchOp.getInput()),
|
||||
rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
auto returnOp = dyn_cast<func::ReturnOp>(*result.getUsers().begin());
|
||||
if (!returnOp)
|
||||
return failure();
|
||||
return result.getUses().begin()->getOperandNumber();
|
||||
}
|
||||
|
||||
static void lowerChannelReceiveTensorBatch(spatial::SpatChannelReceiveTensorBatchOp receiveTensorBatchOp,
|
||||
IRMapping& mapper,
|
||||
IRRewriter& rewriter) {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(receiveTensorBatchOp.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : receiveTensorBatchOp.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(translateSpatialCoreIdToPimCoreId(sourceCoreId));
|
||||
static Value createScaledIndexValue(IRRewriter& rewriter, Location loc, Value base, int64_t scale) {
|
||||
if (scale == 1)
|
||||
return base;
|
||||
|
||||
auto outputType = cast<ShapedType>(receiveTensorBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorBatchOp.getLoc(), outputType);
|
||||
Value received = pim::PimReceiveTensorBatchOp::create(rewriter,
|
||||
receiveTensorBatchOp.getLoc(),
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
.getOutput();
|
||||
mapper.map(receiveTensorBatchOp.getOutput(), received);
|
||||
auto scaleValue = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), scale);
|
||||
return arith::MulIOp::create(rewriter, loc, base, scaleValue).getResult();
|
||||
}
|
||||
|
||||
static Value createHostTargetOffset(IRRewriter& rewriter,
|
||||
tensor::ParallelInsertSliceOp insertSlice,
|
||||
ShapedType destinationType,
|
||||
IRMapping& mapper) {
|
||||
int64_t elementBytes = static_cast<int64_t>(getElementTypeSizeInBytes(destinationType.getElementType()));
|
||||
SmallVector<int64_t> strides = computeRowMajorStrides(destinationType.getShape());
|
||||
|
||||
Value totalOffset;
|
||||
Location loc = insertSlice.getLoc();
|
||||
for (auto [dim, offset] : llvm::enumerate(insertSlice.getMixedOffsets())) {
|
||||
int64_t scale = strides[dim] * elementBytes;
|
||||
Value scaledOffset;
|
||||
if (auto attr = dyn_cast<Attribute>(offset)) {
|
||||
auto intAttr = dyn_cast<IntegerAttr>(attr);
|
||||
assert(intAttr && "expected integer offset attribute");
|
||||
scaledOffset =
|
||||
getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), intAttr.getInt() * scale);
|
||||
}
|
||||
else {
|
||||
scaledOffset = createScaledIndexValue(rewriter, loc, mapper.lookupOrDefault(cast<Value>(offset)), scale);
|
||||
}
|
||||
|
||||
totalOffset =
|
||||
totalOffset ? arith::AddIOp::create(rewriter, loc, totalOffset, scaledOffset).getResult() : scaledOffset;
|
||||
}
|
||||
|
||||
if (!totalOffset)
|
||||
totalOffset = getOrCreateIndexConstant(rewriter, rewriter.getInsertionBlock()->getParentOp(), 0);
|
||||
return totalOffset;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
if (computeBatchOp.getNumResults() != 0)
|
||||
return computeBatchOp.emitOpError(
|
||||
"batched Spatial-to-PIM lowering currently requires channelized compute_batch with no results");
|
||||
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp,
|
||||
IRRewriter& rewriter) {
|
||||
Location loc = computeBatchOp.getLoc();
|
||||
Block& oldBlock = computeBatchOp.getBody().front();
|
||||
auto oldYield = cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
if (oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("batched Spatial-to-PIM lowering currently requires empty spat.yield");
|
||||
auto oldYield = dyn_cast<spatial::SpatYieldOp>(oldBlock.getTerminator());
|
||||
auto inParallelOp = dyn_cast<spatial::SpatInParallelOp>(oldBlock.getTerminator());
|
||||
if (computeBatchOp.getNumResults() == 0) {
|
||||
if (!oldYield || oldYield.getNumOperands() != 0)
|
||||
return computeBatchOp.emitOpError("resultless compute_batch lowering requires empty spat.yield");
|
||||
}
|
||||
else if (!inParallelOp) {
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires a spat.in_parallel terminator");
|
||||
}
|
||||
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, state.nextCoreId);
|
||||
SmallVector<int32_t> coreIds = getPimCoreIdsForBatchOp(computeBatchOp, coreId);
|
||||
SmallVector<Value> batchWeights(computeBatchOp.getWeights().begin(), computeBatchOp.getWeights().end());
|
||||
SmallVector<Value> batchInputs;
|
||||
if (!computeBatchOp.getInputs().empty())
|
||||
@@ -91,9 +118,22 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<Type> blockArgTypes;
|
||||
SmallVector<Location> blockArgLocs;
|
||||
for (BlockArgument arg : oldBlock.getArguments()) {
|
||||
unsigned inputArgLimit = 1 + computeBatchOp.getWeights().size() + computeBatchOp.getInputs().size();
|
||||
for (BlockArgument arg : oldBlock.getArguments().take_front(inputArgLimit)) {
|
||||
blockArgTypes.push_back(arg.getType());
|
||||
blockArgLocs.push_back(arg.getLoc());
|
||||
}
|
||||
@@ -102,38 +142,44 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
|
||||
IRMapping mapper;
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
for (auto [oldArg, newArg] : llvm::zip(oldBlock.getArguments(), newBlock->getArguments())) {
|
||||
auto oldLaneArg = computeBatchOp.getLaneArgument();
|
||||
if (!oldLaneArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
|
||||
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
|
||||
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
|
||||
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
|
||||
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
if (!oldArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
|
||||
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
outputBuffer,
|
||||
newArg,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||
.getOutput();
|
||||
mapper.map(oldArg, copied);
|
||||
mapper.map(*oldArg, copied);
|
||||
}
|
||||
|
||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||
if (auto mapped = mapper.lookupOrNull(capturedTensor))
|
||||
return mapped;
|
||||
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
||||
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
||||
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
||||
if (hostOutputTensor)
|
||||
return hostOutputTensor;
|
||||
|
||||
auto capturedType = cast<ShapedType>(capturedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, capturedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
capturedTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, capturedTensor))
|
||||
.getOutput();
|
||||
mapper.map(capturedTensor, copied);
|
||||
return copied;
|
||||
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
|
||||
return hostOutputTensor;
|
||||
};
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
@@ -141,36 +187,37 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto sendBatchOp = dyn_cast<spatial::SpatChannelSendBatchOp>(op)) {
|
||||
pim::PimSendBatchOp::create(rewriter,
|
||||
loc,
|
||||
mapper.lookup(sendBatchOp.getInput()),
|
||||
getTensorSizeInBytesAttr(rewriter, mapper.lookup(sendBatchOp.getInput())),
|
||||
sendBatchOp.getTargetCoreIdsAttr());
|
||||
continue;
|
||||
}
|
||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
|
||||
for (Operation& nestedOp : parallelOp.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
|
||||
if (!insertSlice)
|
||||
return parallelOp.emitOpError("expected only tensor.parallel_insert_slice in spat.in_parallel");
|
||||
|
||||
if (auto sendTensorBatchOp = dyn_cast<spatial::SpatChannelSendTensorBatchOp>(op)) {
|
||||
lowerChannelSendTensorBatch(sendTensorBatchOp, mapper, rewriter);
|
||||
continue;
|
||||
}
|
||||
auto outputArg = dyn_cast<BlockArgument>(insertSlice.getDest());
|
||||
if (!outputArg || outputArg.getOwner() != &oldBlock)
|
||||
return insertSlice.emitOpError("expected compute_batch output block argument destination");
|
||||
|
||||
if (auto receiveBatchOp = dyn_cast<spatial::SpatChannelReceiveBatchOp>(op)) {
|
||||
auto outputType = cast<ShapedType>(receiveBatchOp.getOutput().getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, outputType);
|
||||
auto received = pim::PimReceiveBatchOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, receiveBatchOp.getOutput()),
|
||||
receiveBatchOp.getSourceCoreIdsAttr())
|
||||
.getOutput();
|
||||
mapper.map(receiveBatchOp.getOutput(), received);
|
||||
continue;
|
||||
}
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
|
||||
if (auto receiveTensorBatchOp = dyn_cast<spatial::SpatChannelReceiveTensorBatchOp>(op)) {
|
||||
lowerChannelReceiveTensorBatch(receiveTensorBatchOp, mapper, rewriter);
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
pim::PimMemCopyDevToHostOp::create(rewriter,
|
||||
insertSlice.getLoc(),
|
||||
hostTarget.getType(),
|
||||
hostTargetOffset,
|
||||
zeroOffset,
|
||||
hostTarget,
|
||||
mappedSource,
|
||||
getTensorSizeInBytesAttr(rewriter, mappedSource));
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -178,15 +225,20 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
if (isa_and_present<memref::GetGlobalOp>(toTensorOp.getBuffer().getDefiningOp())) {
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
auto clonedTensor = cloned->getResult(0);
|
||||
if (isUsedOnlyAsExplicitHostOperand(toTensorOp.getResult())) {
|
||||
mapper.map(toTensorOp.getResult(), clonedTensor);
|
||||
continue;
|
||||
}
|
||||
auto clonedType = cast<ShapedType>(clonedTensor.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, clonedType);
|
||||
auto copied = pim::PimMemCopyHostToDevBatchOp::create(rewriter,
|
||||
Value zeroOffset = getOrCreateIndexConstant(rewriter, coreBatchOp.getOperation(), 0);
|
||||
auto copied = pim::PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
zeroOffset,
|
||||
zeroOffset,
|
||||
outputBuffer,
|
||||
clonedTensor,
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, clonedTensor))
|
||||
.getOutput();
|
||||
mapper.map(toTensorOp.getResult(), copied);
|
||||
@@ -194,15 +246,18 @@ lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState&
|
||||
}
|
||||
}
|
||||
|
||||
for (Value operand : op.getOperands()) {
|
||||
for (auto [operandIndex, operand] : llvm::enumerate(op.getOperands())) {
|
||||
if (!isa<TensorType>(operand.getType()) || mapper.contains(operand))
|
||||
continue;
|
||||
if (isExplicitDevToHostTargetOperand(&op, operandIndex))
|
||||
continue;
|
||||
|
||||
Operation* definingOp = operand.getDefiningOp();
|
||||
if (definingOp && definingOp->getBlock() == &oldBlock)
|
||||
continue;
|
||||
|
||||
materializeCapturedTensor(operand);
|
||||
return computeBatchOp.emitOpError(
|
||||
"expected external tensor communication to be materialized in Spatial before batch lowering");
|
||||
}
|
||||
|
||||
Operation* cloned = rewriter.clone(op, mapper);
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeBatchOp(spatial::SpatComputeBatch computeBatchOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -3,17 +3,17 @@ mlir_tablegen(SpatialToPim.hpp.inc -gen-rewriters "-I${ONNX_MLIR_SRC_ROOT}")
|
||||
add_public_tablegen_target(SpatialToPimIncGen)
|
||||
|
||||
add_pim_library(OMSpatialToPim
|
||||
Patterns.cpp
|
||||
SpatialToPimPass.cpp
|
||||
BatchCoreLoweringPatterns.cpp
|
||||
ChannelLoweringPatterns.cpp
|
||||
Cleanup.cpp
|
||||
Common.cpp
|
||||
ComputeLikeRegionUtils.cpp
|
||||
CoreLoweringPatterns.cpp
|
||||
GlobalTensorMaterialization.cpp
|
||||
PhaseVerification.cpp
|
||||
ReturnPathNormalization.cpp
|
||||
TensorPackingPatterns.cpp
|
||||
Patterns/ChannelLowering.cpp
|
||||
Patterns/GlobalTensorMaterialization.cpp
|
||||
Patterns/TensorPacking.cpp
|
||||
Patterns/Transpose.cpp
|
||||
|
||||
EXCLUDE_FROM_OM_LIBS
|
||||
|
||||
@@ -21,7 +21,10 @@ add_pim_library(OMSpatialToPim
|
||||
SpatialToPimIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRLinalgDialect
|
||||
MLIRSCFDialect
|
||||
MLIRSCFUtils
|
||||
MLIRTransformUtils
|
||||
MLIRTosaDialect
|
||||
OMCompilerOptions
|
||||
OMPimCommon
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,42 +0,0 @@
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Cleanup.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult erasePendingOps(SmallVectorImpl<Operation*>& pendingOps, IRRewriter& rewriter) {
|
||||
while (!pendingOps.empty()) {
|
||||
bool erasedAnyOp = false;
|
||||
for (auto it = pendingOps.begin(); it != pendingOps.end();) {
|
||||
Operation* opToRemove = *it;
|
||||
if (!opToRemove->use_empty()) {
|
||||
++it;
|
||||
continue;
|
||||
}
|
||||
|
||||
rewriter.eraseOp(opToRemove);
|
||||
it = pendingOps.erase(it);
|
||||
erasedAnyOp = true;
|
||||
}
|
||||
|
||||
if (erasedAnyOp)
|
||||
continue;
|
||||
|
||||
for (Operation* opToRemove : pendingOps) {
|
||||
InFlightDiagnostic diag = opToRemove->emitError("pending Spatial-to-PIM cleanup could not erase operation");
|
||||
diag << "; op has " << llvm::range_size(opToRemove->getUsers()) << " remaining user(s)";
|
||||
for (Operation* user : opToRemove->getUsers()) {
|
||||
bool userPendingRemoval = llvm::is_contained(pendingOps, user);
|
||||
opToRemove->emitRemark() << "remaining user `" << user->getName() << "`"
|
||||
<< (userPendingRemoval ? " is also pending removal" : " is not pending removal");
|
||||
}
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,11 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult erasePendingOps(llvm::SmallVectorImpl<mlir::Operation*>& pendingOps, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,10 +1,8 @@
|
||||
#include "mlir/IR/ValueRange.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
#include "Common.hpp"
|
||||
|
||||
@@ -13,52 +11,6 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputShape) {
|
||||
/*
|
||||
EXAMPLE RUN:
|
||||
[1, 10, 3, 4] inputShape
|
||||
[0, 2, 1, 3] offsets
|
||||
|
||||
acc = 1
|
||||
---
|
||||
ret = 3
|
||||
acc = 4
|
||||
---
|
||||
ret = 3 + 4 * 1 = 7
|
||||
acc = 12
|
||||
---
|
||||
ret = 7 + 12 * 2 = 31
|
||||
acc = 120
|
||||
---
|
||||
ret = 31 + 120 * 0 = 31
|
||||
acc = 120
|
||||
*/
|
||||
|
||||
size_t returnValue = 0;
|
||||
|
||||
auto sliceOffsets = sliceOp.getStaticOffsets();
|
||||
auto inputDimSizes = inputShape.getShape();
|
||||
|
||||
assert(sliceOffsets.size() == inputDimSizes.size());
|
||||
|
||||
size_t accumulatedDimensionSize = 1;
|
||||
|
||||
// Reverse iterate the two vectors
|
||||
for (auto it : reverse(zip(sliceOffsets, inputDimSizes))) {
|
||||
auto curSliceOffset = std::get<0>(it);
|
||||
auto curInputDimSize = std::get<1>(it);
|
||||
|
||||
returnValue += accumulatedDimensionSize * curSliceOffset;
|
||||
accumulatedDimensionSize *= curInputDimSize;
|
||||
}
|
||||
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
size_t getShapedTypeSizeInBytes(ShapedType shapedType) {
|
||||
return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8;
|
||||
}
|
||||
|
||||
IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) {
|
||||
return builder.getI32IntegerAttr(static_cast<int32_t>(getShapedTypeSizeInBytes(cast<ShapedType>(value.getType()))));
|
||||
}
|
||||
|
||||
@@ -6,22 +6,6 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
/**
|
||||
* \brief Get the offset of the ExtractSliceOp based on its static offsets and
|
||||
* its static tensor input.
|
||||
*
|
||||
* The static offsets represent the starting position of the slice in each
|
||||
* dimension, while the static tensor input gives its dimension size.
|
||||
*
|
||||
* \param sliceOp The ExtractSliceOp for which the actual offset needs to be
|
||||
* calculated.
|
||||
* \param inputShape The ShapedType of the ExtractSliceOp's input tensor
|
||||
* \return The actual offset of the ExtractSliceOp.
|
||||
*/
|
||||
size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape);
|
||||
|
||||
size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType);
|
||||
|
||||
mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value);
|
||||
|
||||
template <class T>
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -29,7 +31,18 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
unsigned inputIndex,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = body.getArgument(inputIndex);
|
||||
BlockArgument bodyArgument;
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
auto computeArg = compute.getInputArgument(inputIndex);
|
||||
assert(computeArg && "expected compute input block argument");
|
||||
bodyArgument = *computeArg;
|
||||
}
|
||||
else {
|
||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
assert(batchArg && "expected compute_batch input block argument");
|
||||
bodyArgument = *batchArg;
|
||||
}
|
||||
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
bodyArgument.replaceAllUsesWith(replacement);
|
||||
@@ -37,7 +50,7 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
compute.getInputsMutable().erase(inputIndex);
|
||||
else
|
||||
cast<spatial::SpatComputeBatch>(owner).getInputsMutable().erase(inputIndex);
|
||||
body.eraseArgument(inputIndex);
|
||||
body.eraseArgument(bodyArgIndex);
|
||||
rewriter.finalizeOpModification(owner);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/CoreLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -23,11 +25,12 @@ static bool isChannelUseChainOp(Operation* op) {
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CastOp,
|
||||
tosa::ReshapeOp,
|
||||
ONNXTransposeOp,
|
||||
linalg::TransposeOp,
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
@@ -36,7 +39,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||
mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<tensor::EmptyOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
@@ -46,8 +54,6 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
}
|
||||
}
|
||||
|
||||
static int32_t translateSpatialCoreIdToPimCoreId(size_t spatialCoreId) { return static_cast<int32_t>(spatialCoreId); }
|
||||
|
||||
static int32_t getPimCoreIdForComputeOp(spatial::SpatCompute computeOp, size_t& fallbackCoreId) {
|
||||
if (auto spatialCoreIdAttr = computeOp->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
return static_cast<int32_t>(spatialCoreIdAttr.getInt());
|
||||
@@ -92,7 +98,9 @@ static LogicalResult collectHelperComputeChain(spatial::SpatCompute computeOp,
|
||||
return success();
|
||||
}
|
||||
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp, IRRewriter& rewriter) {
|
||||
static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
if (!computeOp.getInputs().empty() || computeOp.getNumResults() != 1)
|
||||
return false;
|
||||
if (!llvm::all_of(computeOp.getResult(0).getUsers(), [](Operation* user) {
|
||||
@@ -101,7 +109,7 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
return false;
|
||||
|
||||
Block& block = computeOp.getBody().front();
|
||||
if (block.getNumArguments() != 0)
|
||||
if (block.getNumArguments() != computeOp.getWeights().size())
|
||||
return false;
|
||||
|
||||
auto yieldOp = dyn_cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
@@ -110,8 +118,14 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
|
||||
auto weightArg = computeOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
mapping.map(*weightArg, weight);
|
||||
}
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter);
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op.getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -125,15 +139,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
} // namespace
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, IRRewriter& rewriter) {
|
||||
LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute computeOp,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder) {
|
||||
Location loc = computeOp->getLoc();
|
||||
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter))
|
||||
if (inlineInputlessHelperComputeForWeightLikeUsers(computeOp, rewriter, constantFolder))
|
||||
return success();
|
||||
|
||||
SmallVector<Operation*> helperChain;
|
||||
@@ -143,21 +154,24 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
auto& block = computeOp.getRegion().front();
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments())) {
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(computeOp.getInputs()[argIndex].getDefiningOp());
|
||||
if (!receiveOp || blockArg.use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during lowering");
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (receiveOp && !blockArg->use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||
auto sourceCoreIdAttr = rewriter.getI32IntegerAttr(translateSpatialCoreIdToPimCoreId(receiveOp.getSourceCoreId()));
|
||||
Value received = PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, sourceCoreIdAttr)
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
|
||||
Value received =
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
markOpToRemove(state, receiveOp);
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (computeOp.getNumResults() != yieldOp.getNumOperands())
|
||||
@@ -167,9 +181,8 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
|
||||
ReturnPathState returnPathState {state.outputTensors, state.operationsToRemove};
|
||||
ReturnPathLoweringResult returnPathResult =
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, returnPathState, rewriter);
|
||||
lowerComputeResultReturnPath(computeOp, cast<OpResult>(result), yieldValue, rewriter);
|
||||
if (returnPathResult == ReturnPathLoweringResult::Failure)
|
||||
return failure();
|
||||
if (returnPathResult == ReturnPathLoweringResult::Handled)
|
||||
@@ -193,15 +206,40 @@ LogicalResult lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState&
|
||||
if (!computeOp.getWeights().empty())
|
||||
computeWeights.append(computeOp.getWeights().begin(), computeOp.getWeights().end());
|
||||
rewriter.setInsertionPointAfter(computeOp);
|
||||
auto coreOp = PimCoreOp::create(rewriter,
|
||||
loc,
|
||||
ValueRange(computeWeights),
|
||||
rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, state.nextCoreId)));
|
||||
auto coreOp = PimCoreOp::create(
|
||||
rewriter, loc, ValueRange(computeWeights), rewriter.getI32IntegerAttr(getPimCoreIdForComputeOp(computeOp, coreId)));
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [argIndex, blockArg] : llvm::enumerate(block.getArguments()))
|
||||
if (!blockArg.use_empty())
|
||||
blockArg.replaceAllUsesWith(computeOp.getInputs()[argIndex]);
|
||||
block.eraseArguments(0, block.getNumArguments());
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during input materialization");
|
||||
if (blockArg->use_empty())
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg->replaceAllUsesWith(getOrCreateConstantLike(constantFolder, constantOp));
|
||||
continue;
|
||||
}
|
||||
|
||||
auto inputType = dyn_cast<ShapedType>(input.getType());
|
||||
if (!inputType)
|
||||
return computeOp.emitOpError("expected shaped compute input during pim.core lowering");
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, inputType);
|
||||
auto copied =
|
||||
PimMemCopyHostToDevOp::create(rewriter,
|
||||
loc,
|
||||
outputBuffer.getType(),
|
||||
getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
|
||||
getOrCreateIndexConstant(constantFolder, outputBuffer.getOperation(), 0),
|
||||
outputBuffer,
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
.getOutput();
|
||||
blockArg->replaceAllUsesWith(copied);
|
||||
}
|
||||
if (!computeOp.getInputs().empty())
|
||||
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
|
||||
coreOpBlocks.splice(coreOpBlocks.begin(), computeOp.getBody().getBlocks());
|
||||
Block* tempComputeBlock = new Block();
|
||||
computeOp.getBody().push_back(tempComputeBlock);
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
struct CoreLoweringState {
|
||||
size_t& nextCoreId;
|
||||
llvm::SmallVectorImpl<OutputTensorFactory>& outputTensors;
|
||||
llvm::SmallVectorImpl<mlir::Operation*>& operationsToRemove;
|
||||
};
|
||||
|
||||
void markOpToRemove(CoreLoweringState& state, mlir::Operation* op);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerComputeOp(spatial::SpatCompute computeOp, CoreLoweringState& state, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,9 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace raptor {
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/SpatialToPim.hpp.inc"
|
||||
|
||||
} // namespace raptor
|
||||
|
||||
void populateInitialPatterns(RewritePatternSet& patterns) {
|
||||
raptor::populateWithGenerated(patterns);
|
||||
populateTransposeLoweringPatterns(patterns);
|
||||
}
|
||||
|
||||
void populateCoreBodyPatterns(RewritePatternSet& patterns) {
|
||||
raptor::populateWithGenerated(patterns);
|
||||
populateTransposeLoweringPatterns(patterns);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+8
-1
@@ -8,6 +8,14 @@
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
void populateInitialPatterns(mlir::RewritePatternSet& patterns);
|
||||
void populateCoreBodyPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
void populateTransposeLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||
void populateChannelLoweringPatterns(mlir::RewritePatternSet& patterns);
|
||||
void populateGlobalTensorMaterializationPatterns(mlir::RewritePatternSet& patterns);
|
||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||
|
||||
mlir::RankedTensorType getPackedTensorType(mlir::RankedTensorType elementType, int64_t count);
|
||||
mlir::Value extractPackedChunk(mlir::Value packedValue,
|
||||
mlir::RankedTensorType chunkType,
|
||||
@@ -20,7 +28,6 @@ mlir::Value createPackedExtractRowsSlice(spatial::SpatExtractRowsOp extractRowsO
|
||||
mlir::OpBuilder& builder,
|
||||
mlir::Location loc);
|
||||
mlir::Value createPackedExtractSliceTensor(mlir::ValueRange values, mlir::OpBuilder& builder, mlir::Location loc);
|
||||
void populateTensorPackingPatterns(mlir::RewritePatternSet& patterns);
|
||||
void eraseUnusedTensorPackingOps(mlir::func::FuncOp funcOp, mlir::IRRewriter& rewriter);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+5
-49
@@ -1,7 +1,7 @@
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ChannelLoweringPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -10,17 +10,12 @@ using namespace mlir;
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
static int32_t toPimCoreId(int32_t spatialCoreId) { return spatialCoreId; }
|
||||
|
||||
struct ChannelSendLowering : OpRewritePattern<spatial::SpatChannelSendOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendOp op, PatternRewriter& rewriter) const override {
|
||||
pim::PimSendOp::create(rewriter,
|
||||
op.getLoc(),
|
||||
op.getInput(),
|
||||
getTensorSizeInBytesAttr(rewriter, op.getInput()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getTargetCoreId())));
|
||||
pim::PimSendOp::create(
|
||||
rewriter, op.getLoc(), op.getInput(), getTensorSizeInBytesAttr(rewriter, op.getInput()), op.getTargetCoreId());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
@@ -42,41 +37,7 @@ struct ChannelReceiveLowering : OpRewritePattern<spatial::SpatChannelReceiveOp>
|
||||
op.getResult().getType(),
|
||||
outputBuffer,
|
||||
getTensorSizeInBytesAttr(rewriter, op.getResult()),
|
||||
rewriter.getI32IntegerAttr(toPimCoreId(op.getSourceCoreId())))
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelSendTensorLowering : OpRewritePattern<spatial::SpatChannelSendTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelSendTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> targetCoreIds;
|
||||
targetCoreIds.reserve(op.getTargetCoreIds().size());
|
||||
for (int32_t targetCoreId : op.getTargetCoreIds())
|
||||
targetCoreIds.push_back(toPimCoreId(targetCoreId));
|
||||
pim::PimSendTensorOp::create(rewriter, op.getLoc(), op.getInput(), rewriter.getDenseI32ArrayAttr(targetCoreIds));
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ChannelReceiveTensorLowering : OpRewritePattern<spatial::SpatChannelReceiveTensorOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(spatial::SpatChannelReceiveTensorOp op, PatternRewriter& rewriter) const override {
|
||||
SmallVector<int32_t> sourceCoreIds;
|
||||
sourceCoreIds.reserve(op.getSourceCoreIds().size());
|
||||
for (int32_t sourceCoreId : op.getSourceCoreIds())
|
||||
sourceCoreIds.push_back(toPimCoreId(sourceCoreId));
|
||||
auto outputType = cast<ShapedType>(op.getOutput().getType());
|
||||
Value outputBuffer =
|
||||
tensor::EmptyOp::create(rewriter, op.getLoc(), outputType.getShape(), outputType.getElementType()).getResult();
|
||||
Value received =
|
||||
pim::PimReceiveTensorOp::create(
|
||||
rewriter, op.getLoc(), op.getOutput().getType(), outputBuffer, rewriter.getDenseI32ArrayAttr(sourceCoreIds))
|
||||
op.getSourceCoreId())
|
||||
.getOutput();
|
||||
rewriter.replaceOp(op, received);
|
||||
return success();
|
||||
@@ -125,12 +86,7 @@ struct ConcatLowering : OpRewritePattern<spatial::SpatConcatOp> {
|
||||
} // namespace
|
||||
|
||||
void populateChannelLoweringPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<ChannelSendLowering,
|
||||
ChannelReceiveLowering,
|
||||
ChannelSendTensorLowering,
|
||||
ChannelReceiveTensorLowering,
|
||||
ExtractRowsLowering,
|
||||
ConcatLowering>(patterns.getContext());
|
||||
patterns.add<ChannelSendLowering, ChannelReceiveLowering, ExtractRowsLowering, ConcatLowering>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+12
-175
@@ -16,7 +16,7 @@
|
||||
|
||||
#include "Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/GlobalTensorMaterialization.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -76,10 +76,11 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatCompute.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
@@ -89,16 +90,17 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
rewriter, spatCompute.getOperation(), *inputIndex, mapSpatToExtract[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = dyn_cast<spatial::SpatComputeBatch>(uses.getOwner())) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, uses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
auto BBArgValue = spatComputeBatch.getBody().front().getArgument(BBArgIndex);
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
@@ -108,7 +110,7 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
rewriter, spatComputeBatch.getOperation(), *inputIndex, mapSpatToExtract[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
@@ -143,170 +145,6 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
}
|
||||
};
|
||||
|
||||
// Turns runtime constants consumed by compute regions into private globals and local loads.
|
||||
struct ArithConstToGlobalMemoryPattern final : OpRewritePattern<mlir::arith::ConstantOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(mlir::arith::ConstantOp constantOp, PatternRewriter& rewriter) const override {
|
||||
Location loc = constantOp.getLoc();
|
||||
|
||||
if (hasWeightAlways(constantOp))
|
||||
return failure();
|
||||
|
||||
if (!isa<func::FuncOp>(constantOp->getParentOp()))
|
||||
return failure();
|
||||
|
||||
if (llvm::all_of(constantOp->getUsers(), [](Operation* op) {
|
||||
if (isa<spatial::SpatCompute>(op))
|
||||
return false;
|
||||
if (isa<func::FuncOp>(op->getParentOp()))
|
||||
return true;
|
||||
return false;
|
||||
}))
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(constantOp->getParentOfType<func::FuncOp>());
|
||||
|
||||
auto constRankedTensorType = llvm::dyn_cast<mlir::RankedTensorType>(constantOp.getType());
|
||||
|
||||
if (constRankedTensorType) {
|
||||
mlir::MemRefType memRefType =
|
||||
mlir::MemRefType::get(constRankedTensorType.getShape(), constRankedTensorType.getElementType());
|
||||
auto globalOp = createPrivateMemrefGlobalWithUniqueName(rewriter,
|
||||
loc,
|
||||
constantOp->getParentOfType<ModuleOp>(),
|
||||
"const",
|
||||
memRefType,
|
||||
constantOp.getValueAttr(),
|
||||
rewriter.getUnitAttr());
|
||||
std::string argName = globalOp.getSymName().str();
|
||||
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(rewriter,
|
||||
spatComputeBatch.getOperation(),
|
||||
BBArgIndex,
|
||||
mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
}
|
||||
else {
|
||||
{
|
||||
|
||||
if (auto spatCompute = constUses.getOwner()->getParentOfType<spatial::SpatCompute>()) {
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatCompute.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatCompute.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatCompute.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatCompute.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatCompute.getOperation());
|
||||
}
|
||||
else if (auto spatComputeBatch = constUses.getOwner()->getParentOfType<spatial::SpatComputeBatch>()) {
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
if (!mapSpatComputeToConst.contains(spatComputeBatch.getOperation())) {
|
||||
auto getGlobalOp = memref::GetGlobalOp::create(rewriter, loc, memRefType, argName);
|
||||
auto toTensor = bufferization::ToTensorOp::create(
|
||||
rewriter, loc, constRankedTensorType, getGlobalOp, rewriter.getUnitAttr(), rewriter.getUnitAttr());
|
||||
mapSpatComputeToConst.insert({spatComputeBatch.getOperation(), toTensor.getResult()});
|
||||
}
|
||||
|
||||
rewriter.startOpModification(spatComputeBatch.getOperation());
|
||||
constUses.set(mapSpatComputeToConst[spatComputeBatch.getOperation()]);
|
||||
rewriter.finalizeOpModification(spatComputeBatch.getOperation());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (constantOp.getType().isIntOrIndexOrFloat()) {
|
||||
llvm::DenseMap<Operation*, Value> mapSpatComputeToConst;
|
||||
|
||||
for (auto& constUses : llvm::make_early_inc_range(constantOp->getUses())) {
|
||||
auto constUsers = constUses.getOwner();
|
||||
|
||||
if (auto spatCompute = llvm::dyn_cast<spatial::SpatCompute>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatCompute, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatCompute.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
}
|
||||
else if (auto spatComputeBatch = llvm::dyn_cast<spatial::SpatComputeBatch>(constUsers)) {
|
||||
auto inputIndex = getDirectComputeLikeInputIndex(spatComputeBatch, constUses.getOperandNumber());
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgIndex = *inputIndex;
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
|
||||
replaceAndEraseDirectComputeLikeInput(
|
||||
rewriter, spatComputeBatch.getOperation(), BBArgIndex, newConst->getResult(0));
|
||||
}
|
||||
else if (auto parent = constUsers->getParentOfType<spatial::SpatCompute>()) {
|
||||
if (!mapSpatComputeToConst.contains(parent)) {
|
||||
rewriter.setInsertionPoint(&parent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({parent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[parent.getOperation()]);
|
||||
}
|
||||
else {
|
||||
auto batchParent = constUsers->getParentOfType<spatial::SpatComputeBatch>();
|
||||
assert(batchParent && "Global Constant used direcly not within a compute");
|
||||
if (!mapSpatComputeToConst.contains(batchParent.getOperation())) {
|
||||
rewriter.setInsertionPoint(&batchParent.getBody().front().front());
|
||||
auto newConst = rewriter.clone(*constantOp);
|
||||
mapSpatComputeToConst.insert({batchParent.getOperation(), newConst->getResult(0)});
|
||||
}
|
||||
constUses.set(mapSpatComputeToConst[batchParent.getOperation()]);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (constantOp->use_empty())
|
||||
rewriter.eraseOp(constantOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Materializes public function tensor inputs as globals so compute bodies can load them uniformly.
|
||||
struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
@@ -383,8 +221,7 @@ struct FuncOpArgToGlobalMemoryPattern final : OpRewritePattern<mlir::func::FuncO
|
||||
|
||||
} // namespace
|
||||
void populateGlobalTensorMaterializationPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern, ArithConstToGlobalMemoryPattern>(
|
||||
patterns.getContext());
|
||||
patterns.add<MoveExtractSliceIntoCompute, FuncOpArgToGlobalMemoryPattern>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
@@ -0,0 +1,38 @@
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Patterns.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace {
|
||||
|
||||
struct LinalgTransposeToPim final : OpRewritePattern<linalg::TransposeOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter& rewriter) const override {
|
||||
SmallVector<Attribute> permutationAttrs;
|
||||
permutationAttrs.reserve(transposeOp.getPermutation().size());
|
||||
for (int64_t dim : transposeOp.getPermutation())
|
||||
permutationAttrs.push_back(rewriter.getI64IntegerAttr(dim));
|
||||
|
||||
auto permutation = rewriter.getArrayAttr(permutationAttrs);
|
||||
auto pimTranspose = pim::PimTransposeOp::create(rewriter,
|
||||
transposeOp.getLoc(),
|
||||
TypeRange {transposeOp->getResult(0).getType()},
|
||||
transposeOp.getInput(),
|
||||
permutation,
|
||||
transposeOp.getInit());
|
||||
rewriter.replaceOp(transposeOp, pimTranspose.getOutput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void populateTransposeLoweringPatterns(RewritePatternSet& patterns) {
|
||||
patterns.add<LinalgTransposeToPim>(patterns.getContext());
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,20 +0,0 @@
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/PhaseVerification.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
LogicalResult verifySpatialToPimBoundary(ModuleOp moduleOp) {
|
||||
bool hasFailure = false;
|
||||
moduleOp.walk([&](Operation* op) {
|
||||
if (op->getDialect()->getNamespace() != "spat")
|
||||
return;
|
||||
|
||||
op->emitError("illegal Spatial operation remains after Spatial-to-PIM lowering");
|
||||
hasFailure = true;
|
||||
});
|
||||
return success(!hasFailure);
|
||||
}
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,9 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
|
||||
namespace onnx_mlir {
|
||||
|
||||
mlir::LogicalResult verifySpatialToPimBoundary(mlir::ModuleOp moduleOp);
|
||||
|
||||
} // namespace onnx_mlir
|
||||
@@ -1,15 +1,18 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
|
||||
#include "Conversion/ONNXToSpatial/Common/Common.hpp"
|
||||
#include "Conversion/SpatialToPim/SpatialToPimPass.hpp"
|
||||
#include "src/Accelerators/PIM/Common/PimCommon.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/Common.hpp"
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ReturnPathNormalization.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
|
||||
#include "src/Dialect/ONNX/ONNXOps.hpp"
|
||||
|
||||
@@ -38,15 +41,10 @@ static bool isReturnHelperChainOp(Operation* op) {
|
||||
tensor::ExpandShapeOp,
|
||||
tensor::CastOp,
|
||||
tosa::ReshapeOp,
|
||||
ONNXTransposeOp,
|
||||
linalg::TransposeOp,
|
||||
pim::PimTransposeOp>(op);
|
||||
}
|
||||
|
||||
static void markOpToRemove(ReturnPathState& state, Operation* op) {
|
||||
if (!llvm::is_contained(state.operationsToRemove, op))
|
||||
state.operationsToRemove.push_back(op);
|
||||
}
|
||||
|
||||
static std::string makeUniqueSymbolName(Operation* symbolTableOp, StringRef baseName) {
|
||||
std::string name = baseName.str();
|
||||
unsigned suffix = 0;
|
||||
@@ -279,11 +277,10 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto transposeOp = dyn_cast<ONNXTransposeOp>(op)) {
|
||||
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(op)) {
|
||||
SmallVector<int64_t> nextIndices(currentIndices.size());
|
||||
SmallVector<int64_t> nextShape(currentShape.size());
|
||||
for (auto [destIndex, attr] : llvm::enumerate(transposeOp.getPermAttr().getAsRange<IntegerAttr>())) {
|
||||
int64_t sourceIndex = attr.getInt();
|
||||
for (auto [destIndex, sourceIndex] : llvm::enumerate(transposeOp.getPermutation())) {
|
||||
nextIndices[destIndex] = currentIndices[sourceIndex];
|
||||
nextShape[destIndex] = currentShape[sourceIndex];
|
||||
}
|
||||
@@ -318,7 +315,8 @@ static LogicalResult mapIndicesThroughHelperChain(ArrayRef<int64_t> sourceIndice
|
||||
return success();
|
||||
}
|
||||
|
||||
static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter) {
|
||||
static void
|
||||
cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewriter& rewriter, OperationFolder& constantFolder) {
|
||||
for (Value operand : op->getOperands()) {
|
||||
if (mapping.lookupOrNull(operand))
|
||||
continue;
|
||||
@@ -327,7 +325,12 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
if (!definingOp)
|
||||
continue;
|
||||
|
||||
if (!isa<tensor::EmptyOp, arith::ConstantOp>(definingOp))
|
||||
if (auto constantOp = dyn_cast<arith::ConstantOp>(definingOp)) {
|
||||
mapping.map(operand, getOrCreateConstantLike(constantFolder, constantOp));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!isa<tensor::EmptyOp>(definingOp))
|
||||
continue;
|
||||
|
||||
Operation* clonedOp = rewriter.clone(*definingOp, mapping);
|
||||
@@ -337,15 +340,18 @@ static void cloneMappedHelperOperands(Operation* op, IRMapping& mapping, IRRewri
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
cloneHelperChain(Value sourceValue, ArrayRef<Operation*> helperChain, IRRewriter& rewriter, Value& clonedValue) {
|
||||
static void cloneHelperChain(Value sourceValue,
|
||||
ArrayRef<Operation*> helperChain,
|
||||
IRRewriter& rewriter,
|
||||
OperationFolder& constantFolder,
|
||||
Value& clonedValue) {
|
||||
IRMapping mapping;
|
||||
mapping.map(sourceValue, sourceValue);
|
||||
clonedValue = sourceValue;
|
||||
|
||||
rewriter.setInsertionPointAfterValue(sourceValue);
|
||||
for (Operation* op : helperChain) {
|
||||
cloneMappedHelperOperands(op, mapping, rewriter);
|
||||
cloneMappedHelperOperands(op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(*op, mapping);
|
||||
for (auto [originalResult, newResult] : llvm::zip(op->getResults(), clonedOp->getResults()))
|
||||
mapping.map(originalResult, newResult);
|
||||
@@ -360,23 +366,26 @@ static Value emitHostCopy(IRRewriter& rewriter,
|
||||
Value sourceValue,
|
||||
int32_t hostTargetOffset,
|
||||
int32_t deviceSourceOffset,
|
||||
int32_t sizeInBytes) {
|
||||
int32_t sizeInBytes,
|
||||
OperationFolder& constantFolder) {
|
||||
Operation* anchorOp = sourceValue.getDefiningOp() ? sourceValue.getDefiningOp() : outputTensor.getDefiningOp();
|
||||
assert(anchorOp && "expected a concrete op anchor for return-path host copy constants");
|
||||
Value hostTargetOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, hostTargetOffset);
|
||||
Value deviceSourceOffsetValue = getOrCreateIndexConstant(constantFolder, anchorOp, deviceSourceOffset);
|
||||
return PimMemCopyDevToHostOp::create(rewriter,
|
||||
loc,
|
||||
outputTensor.getType(),
|
||||
hostTargetOffsetValue,
|
||||
deviceSourceOffsetValue,
|
||||
outputTensor,
|
||||
sourceValue,
|
||||
rewriter.getI32IntegerAttr(hostTargetOffset),
|
||||
rewriter.getI32IntegerAttr(deviceSourceOffset),
|
||||
rewriter.getI32IntegerAttr(sizeInBytes))
|
||||
.getOutput();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
IRRewriter& rewriter,
|
||||
SmallVectorImpl<OutputTensorFactory>& outputTensors) {
|
||||
void raptor::SpatialToPimPass::addReturnOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
outputTensors.reserve(returnOp->getNumOperands());
|
||||
for (auto [index, returnValue] : llvm::enumerate(returnOp->getOperands())) {
|
||||
Value currentReturnValue = returnValue;
|
||||
@@ -411,70 +420,85 @@ void addReturnOutputBuffers(func::ReturnOp returnOp,
|
||||
}
|
||||
}
|
||||
|
||||
ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, ReturnPathState& state, IRRewriter& rewriter) {
|
||||
Location loc = computeOp->getLoc();
|
||||
auto yieldType = cast<TensorType>(yieldValue.getType());
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerProducedValueReturnPath(
|
||||
Operation* producerOp, Value producedValue, Value storedValue, IRRewriter& rewriter) {
|
||||
Location loc = producerOp->getLoc();
|
||||
OperationFolder constantFolder(producerOp->getContext());
|
||||
auto storedTensorType = cast<TensorType>(storedValue.getType());
|
||||
|
||||
if (auto returnUse = analyzeReturnUse(result)) {
|
||||
Value storedValue = yieldValue;
|
||||
cloneHelperChain(yieldValue, returnUse->helperChain, rewriter, storedValue);
|
||||
if (auto returnUse = analyzeReturnUse(producedValue)) {
|
||||
Value currentStoredValue = storedValue;
|
||||
cloneHelperChain(storedValue, returnUse->helperChain, rewriter, constantFolder, currentStoredValue);
|
||||
for (Operation* op : returnUse->helperChain)
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
|
||||
auto storedType = cast<ShapedType>(storedValue.getType());
|
||||
size_t elementSize = storedType.getElementTypeBitWidth() / 8;
|
||||
if (auto storedOp = storedValue.getDefiningOp())
|
||||
auto storedType = cast<ShapedType>(currentStoredValue.getType());
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType());
|
||||
if (auto storedOp = currentStoredValue.getDefiningOp())
|
||||
rewriter.setInsertionPointAfter(storedOp);
|
||||
Value outputTensor = state.outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, storedValue, 0, 0, static_cast<int32_t>(storedType.getNumElements() * elementSize));
|
||||
Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
currentStoredValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto resultUses = result.getUses();
|
||||
auto resultUses = producedValue.getUses();
|
||||
if (rangeLength(resultUses) == 1) {
|
||||
OpOperand& resultUse = *resultUses.begin();
|
||||
Operation* resultUser = resultUse.getOwner();
|
||||
|
||||
if (isa<func::ReturnOp>(resultUser)) {
|
||||
size_t resultIndexInReturn = resultUse.getOperandNumber();
|
||||
size_t elementSize = yieldType.getElementType().getIntOrFloatBitWidth() / 8;
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(
|
||||
rewriter, loc, outputTensor, yieldValue, 0, 0, static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc);
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
storedValue,
|
||||
0,
|
||||
0,
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(result)) {
|
||||
size_t elementSize = yieldType.getElementTypeBitWidth() / 8;
|
||||
if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) {
|
||||
size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType());
|
||||
for (Operation* concatOp : concatReturnUse->concatChain)
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
|
||||
if (concatReturnUse->helperChain.empty()) {
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
int64_t flatOffset = computeFlatElementIndex(concatReturnUse->sliceOffsets, outputType.getShape());
|
||||
emitHostCopy(rewriter,
|
||||
loc,
|
||||
outputTensor,
|
||||
yieldValue,
|
||||
storedValue,
|
||||
static_cast<int32_t>(flatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(yieldType.getNumElements() * elementSize));
|
||||
static_cast<int32_t>(storedTensorType.getNumElements() * elementSize),
|
||||
constantFolder);
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
|
||||
auto storedType = dyn_cast<RankedTensorType>(yieldValue.getType());
|
||||
auto storedType = dyn_cast<RankedTensorType>(storedValue.getType());
|
||||
if (!storedType) {
|
||||
computeOp.emitOpError("has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
producerOp->emitOpError(
|
||||
"has an unsupported non-ranked concat-return helper yield during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
rewriter.setInsertionPointAfterValue(yieldValue);
|
||||
Value outputTensor = state.outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
rewriter.setInsertionPointAfterValue(storedValue);
|
||||
Value outputTensor = outputTensors[concatReturnUse->returnIndex](rewriter, loc);
|
||||
auto outputType = cast<ShapedType>(outputTensor.getType());
|
||||
for (int64_t linearIndex = 0; linearIndex < storedType.getNumElements(); ++linearIndex) {
|
||||
SmallVector<int64_t> sourceIndices = expandFlatElementIndex(linearIndex, storedType.getShape());
|
||||
@@ -484,7 +508,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
SmallVector<int64_t> destinationIndices;
|
||||
if (failed(mapIndicesThroughHelperChain(
|
||||
sourceIndices, concatReturnUse->concatShape, concatReturnUse->helperChain, destinationIndices))) {
|
||||
computeOp.emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
producerOp->emitOpError("has an unsupported concat-return helper chain during Spatial-to-PIM lowering");
|
||||
return ReturnPathLoweringResult::Failure;
|
||||
}
|
||||
|
||||
@@ -503,7 +527,7 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
auto scalarTensorType =
|
||||
RankedTensorType::get(SmallVector<int64_t>(storedType.getRank(), 1), storedType.getElementType());
|
||||
auto elementSlice = tensor::ExtractSliceOp::create(
|
||||
rewriter, loc, scalarTensorType, yieldValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter, loc, scalarTensorType, storedValue, extractOffsets, extractSizes, extractStrides);
|
||||
rewriter.setInsertionPointAfter(elementSlice);
|
||||
|
||||
int64_t destinationFlatOffset = computeFlatElementIndex(destinationIndices, outputType.getShape());
|
||||
@@ -513,7 +537,8 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
elementSlice.getResult(),
|
||||
static_cast<int32_t>(destinationFlatOffset * elementSize),
|
||||
0,
|
||||
static_cast<int32_t>(elementSize));
|
||||
static_cast<int32_t>(elementSize),
|
||||
constantFolder);
|
||||
}
|
||||
return ReturnPathLoweringResult::Handled;
|
||||
}
|
||||
@@ -521,7 +546,12 @@ ReturnPathLoweringResult lowerComputeResultReturnPath(
|
||||
return ReturnPathLoweringResult::NotReturnPath;
|
||||
}
|
||||
|
||||
void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter, ReturnPathState& state) {
|
||||
raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::lowerComputeResultReturnPath(
|
||||
spatial::SpatCompute computeOp, OpResult result, Value yieldValue, IRRewriter& rewriter) {
|
||||
return lowerProducedValueReturnPath(computeOp.getOperation(), result, yieldValue, rewriter);
|
||||
}
|
||||
|
||||
void raptor::SpatialToPimPass::replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewriter) {
|
||||
auto markOwnedReturnChain = [&](Operation* op, auto&& markOwnedReturnChain) -> void {
|
||||
if (!op)
|
||||
return;
|
||||
@@ -538,13 +568,13 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
|
||||
if (isReturnHelperChainOp(op)) {
|
||||
Value source = op->getOperand(0);
|
||||
markOpToRemove(state, op);
|
||||
markOpToRemove(op);
|
||||
markOwnedReturnChain(source.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto computeOp = dyn_cast<spatial::SpatCompute>(op)) {
|
||||
markOpToRemove(state, computeOp);
|
||||
markOpToRemove(computeOp);
|
||||
if (!computeOp.getInputs().empty())
|
||||
for (Value input : computeOp.getInputs())
|
||||
markOwnedReturnChain(input.getDefiningOp(), markOwnedReturnChain);
|
||||
@@ -552,23 +582,29 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<tensor::ConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getOperands())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<spatial::SpatConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto concatOp = dyn_cast<pim::PimConcatOp>(op)) {
|
||||
markOpToRemove(state, concatOp);
|
||||
markOpToRemove(concatOp);
|
||||
for (Value operand : concatOp.getInputs())
|
||||
markOwnedReturnChain(operand.getDefiningOp(), markOwnedReturnChain);
|
||||
return;
|
||||
}
|
||||
|
||||
if (auto receiveOp = dyn_cast<spatial::SpatChannelReceiveOp>(op)) {
|
||||
markOpToRemove(receiveOp);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -578,7 +614,7 @@ void replaceReturnWithOutputBuffers(func::ReturnOp returnOp, IRRewriter& rewrite
|
||||
size_t orderWithinReturn = it.index();
|
||||
Operation* returnOperand = it.value().getDefiningOp();
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Value outputTensor = state.outputTensors[orderWithinReturn](rewriter, loc);
|
||||
Value outputTensor = outputTensors[orderWithinReturn](rewriter, loc);
|
||||
rewriter.modifyOpInPlace(returnOp, [&] { returnOp.setOperand(orderWithinReturn, outputTensor); });
|
||||
markOwnedReturnChain(returnOperand, markOwnedReturnChain);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user